Improve tracing to propagate to backend workers correctly when using the API. Fix distributed tests.

parent 75b63fce91
commit f8fcfa6f08
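The substance of the change: OpenTelemetry context lives in contextvars, so it does not follow work handed to a thread pool on its own; the caller has to capture the current context and re-activate it inside the worker. A reduced, self-contained sketch of that pattern (the tracer name and helper functions below are illustrative, not taken from the repository):

    # Illustrative sketch only: capture the caller's trace context and re-activate it
    # in a thread-pool worker so the worker's span is parented correctly.
    from concurrent.futures import ThreadPoolExecutor

    from opentelemetry import context, trace

    tracer = trace.get_tracer("sketch")
    executor = ThreadPoolExecutor(max_workers=1)

    def run_in_worker():
        parent_context = context.get_current()  # captured on the calling thread

        def work():
            # re-attached here, so "Execute" becomes a child of "Queue"
            with tracer.start_as_current_span("Execute", context=parent_context):
                return "ok"

        return executor.submit(work).result()

    with tracer.start_as_current_span("Queue"):
        print(run_in_worker())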
@@ -7,14 +7,19 @@ from asyncio import AbstractEventLoop
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
 
+from opentelemetry import context
+from opentelemetry.trace import Span, Status, StatusCode
 
 from ..api.components.schema.prompt import PromptDict
 from ..cli_args_types import Configuration
-from ..component_model.make_mutable import make_mutable
+from ..cmd.main_pre import tracer
 from ..component_model.executor_types import ExecutorToClientProgress
+from ..component_model.make_mutable import make_mutable
 from ..distributed.server_stub import ServerStub
 
 _server_stub_instance = ServerStub()
 
 
 class EmbeddedComfyClient:
     """
     Embedded client for comfy executing prompts as a library.
@@ -110,30 +115,36 @@ class EmbeddedComfyClient
 
         await self._loop.run_in_executor(self._executor, create_executor_in_thread)
 
+    @tracer.start_as_current_span("Queue Prompt")
     async def queue_prompt(self,
                            prompt: PromptDict | dict,
                            prompt_id: Optional[str] = None,
                            client_id: Optional[str] = None) -> dict:
         prompt_id = prompt_id or str(uuid.uuid4())
         client_id = client_id or self._progress_handler.client_id or None
+        span_context = context.get_current()
 
         def execute_prompt() -> dict:
-            from ..cmd.execution import PromptExecutor, validate_prompt
-            prompt_mut = make_mutable(prompt)
-            validation_tuple = validate_prompt(prompt_mut)
+            span: Span
+            with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
+                from ..cmd.execution import PromptExecutor, validate_prompt
+                prompt_mut = make_mutable(prompt)
+                validation_tuple = validate_prompt(prompt_mut)
 
                 prompt_executor: PromptExecutor = self._prompt_executor
 
                 if client_id is None:
                     prompt_executor.server = _server_stub_instance
                 else:
                     prompt_executor.server = self._progress_handler
 
                 prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
                                         execute_outputs=validation_tuple[2])
                 if prompt_executor.success:
                     return prompt_executor.outputs_ui
                 else:
-                    raise RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
+                    span.set_status(Status(StatusCode.ERROR))
+                    error = RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
+                    span.record_exception(error)
 
         return await self._loop.run_in_executor(self._executor, execute_prompt)
@@ -9,22 +9,26 @@ import threading
 import traceback
 import typing
 from typing import List, Optional, Tuple, Union
-from typing_extensions import TypedDict
-import torch
-import lazy_object_proxy
 
+import lazy_object_proxy
+import torch
+from opentelemetry.trace import get_current_span, StatusCode, Status
+from typing_extensions import TypedDict
 
+from .main_pre import tracer
 from .. import interruption
-from ..component_model.abstract_prompt_queue import AbstractPromptQueue
-from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
-from ..component_model.executor_types import ExecutorToClientProgress
 from .. import model_management
-from ..nodes.package_typing import ExportedNodes
+from ..component_model.abstract_prompt_queue import AbstractPromptQueue
+from ..component_model.executor_types import ExecutorToClientProgress
+from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..nodes.package import import_all_nodes_in_workspace
+from ..nodes.package_typing import ExportedNodes
 
 # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
 # various functions that are declared here. It should have been a context in the first place.
 nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
 
 
 def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None):
     if extra_data is None:
         extra_data = {}
@@ -139,6 +143,7 @@ def format_value(x):
     return str(x)
 
 
+@tracer.start_as_current_span("Recursive Execute")
 def recursive_execute(server: ExecutorToClientProgress,
                       prompt,
                       outputs,
@@ -324,6 +329,10 @@ class PromptExecutor:
         self.server.send_sync(event, data, self.server.client_id)
 
     def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
+        current_span = get_current_span()
+        current_span.set_status(Status(StatusCode.ERROR))
+        current_span.record_exception(ex)
+
         node_id = error["node_id"]
         class_type = prompt[node_id]["class_type"]
 
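For reference, the error path added above follows the standard OpenTelemetry recipe: fetch the ambient span, mark it failed, and attach the exception as an event. A minimal illustrative sketch, not part of the commit:

    # Illustrative sketch: report a failure on whatever span is currently active.
    from opentelemetry.trace import Status, StatusCode, get_current_span

    def report_failure(ex: Exception) -> None:
        span = get_current_span()          # a no-op span if nothing is recording
        span.set_status(Status(StatusCode.ERROR))
        span.record_exception(ex)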
@@ -430,7 +439,6 @@ class PromptExecutor:
                 # This call shouldn't raise anything if there's an error deep in
                 # the actual SD code, instead it will report the node where the
                 # error was raised
-                # todo: if we're using a distributed queue, we must wrap the server instance to correctly communicate back to the client via the exchange
                 self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id,
                                                             extra_data, executed, prompt_id, self.outputs_ui,
                                                             self.object_storage)
@@ -689,6 +697,7 @@ class ValidationErrorDict(TypedDict):
 ValidationTuple = typing.Tuple[bool, Optional[ValidationErrorDict], typing.List[str], Union[dict, list]]
 
 
+@tracer.start_as_current_span("Validate Prompt")
 def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
     outputs = set()
     for x in prompt:
@@ -13,17 +13,16 @@ import uuid
 from asyncio import Future, AbstractEventLoop
 from enum import Enum
 from io import BytesIO
+from posixpath import join as urljoin
 from typing import List, Optional, Dict
 from urllib.parse import quote, urlencode
-from posixpath import join as urljoin
-
-from can_ada import URL, parse as urlparse
 
 import aiofiles
 import aiohttp
 from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 from aiohttp import web
+from can_ada import URL, parse as urlparse
 from pkg_resources import resource_filename
 from typing_extensions import NamedTuple
 
@@ -36,12 +35,13 @@ from ..cli_args import args
 from ..client.client_types import Output, FileOutput
 from ..cmd import execution
 from ..cmd import folder_paths
+from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.file_output_path import file_output_path
 from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
 from ..digest import digest
-from ..nodes.package_typing import ExportedNodes
 from ..images import open_image
+from ..nodes.package_typing import ExportedNodes
 
 
 class HeuristicPath(NamedTuple):
@@ -95,7 +95,7 @@ class PromptServer(ExecutorToClientProgress):
         self.user_manager = UserManager()
         # todo: this is probably read by custom nodes elsewhere
         self.supports: List[str] = ["custom_nodes_from_web"]
-        self.prompt_queue: execution.AbstractPromptQueue | None = None
+        self.prompt_queue: AbstractPromptQueue | AsyncAbstractPromptQueue | None = None
         self.loop: AbstractEventLoop = loop
         self.messages: asyncio.Queue = asyncio.Queue()
         self.number: int = 0
@@ -599,22 +599,27 @@ class PromptServer(ExecutorToClientProgress):
             return web.Response(status=400, content_type="application/json", body=json.dumps(valid[1]))
 
         # convert a valid prompt to the queue tuple this expects
-        completed: Future[TaskInvocation | dict] = self.loop.create_future()
         number = self.number
         self.number += 1
-        self.prompt_queue.put(
-            QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]),
-                      completed=completed))
 
-        try:
-            await completed
-        except Exception as ex:
-            return web.Response(body=str(ex), status=503)
-        # expect a single image
-        result: TaskInvocation | dict = completed.result()
+        completed: Future[TaskInvocation | dict] = self.loop.create_future()
+        item = QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]), completed=completed)
+        if hasattr(self.prompt_queue, "put_async") or isinstance(self.prompt_queue, AsyncAbstractPromptQueue):
+            # this enables span propagation seamlessly
+            result = await self.prompt_queue.put_async(item)
+            if result is None:
+                return web.Response(body="the queue is shutting down", status=503)
+        else:
+            try:
+                self.prompt_queue.put(item)
+                await completed
+            except Exception as ex:
+                return web.Response(body=str(ex), status=503)
+            # expect a single image
+            result: TaskInvocation | dict = completed.result()
         outputs_dict: Dict[str, Output] = result.outputs if isinstance(result, TaskInvocation) else result
         # find images and read them
 
         output_images: List[FileOutput] = []
         for node_id, node in outputs_dict.items():
             images: List[FileOutput] = []
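The handler above now prefers the asynchronous queue when one is available: awaiting put_async keeps the request inside the same task, so the active trace context flows straight into the queue implementation, while the legacy path still hands the item over and waits on its Future. A reduced sketch of that dispatch, using hypothetical helper names:

    # Illustrative sketch of the dispatch, with hypothetical names.
    async def submit(prompt_queue, item):
        if hasattr(prompt_queue, "put_async"):
            # same task, same contextvars: the span context propagates into the queue
            result = await prompt_queue.put_async(item)
            if result is None:
                raise RuntimeError("queue is shutting down")
            return result
        # legacy queues: enqueue and wait for the worker to resolve the item's future
        prompt_queue.put(item)
        return await item.completed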
@@ -3,7 +3,7 @@ from __future__ import annotations
 import typing
 from abc import ABCMeta, abstractmethod
 
-from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus
+from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus, TaskInvocation
 
 
 class AbstractPromptQueue(metaclass=ABCMeta):
@@ -118,3 +118,13 @@ class AbstractPromptQueue(metaclass=ABCMeta):
         :return:
         """
         pass
+
+
+class AsyncAbstractPromptQueue(AbstractPromptQueue):
+    @abstractmethod
+    async def put_async(self, queue_item) -> TaskInvocation | None:
+        pass
+
+    @abstractmethod
+    async def get_async(self, timeout: float | None = None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
+        pass
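Loosely, an implementation of this contract only has to resolve put_async with the finished invocation, or with None once it is closing; a toy in-process version, purely for illustration and not part of the commit:

    # Illustrative toy implementation of the new async queue contract.
    class InProcessAsyncQueue:
        def __init__(self, execute):
            self._execute = execute   # coroutine function: runs an item, returns its outputs
            self._closing = False

        async def put_async(self, queue_item):
            if self._closing:
                return None            # mirrors the "shutting down" signal the server checks
            return await self._execute(queue_item)

        async def get_async(self, timeout: float | None = None):
            return None                # nothing is ever parked in this toy queue

        def close(self):
            self._closing = True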
@@ -18,13 +18,14 @@ from .distributed_types import RpcRequest, RpcReply
 from .history import History
 from .server_stub import ServerStub
 from ..auth.permissions import jwt_decode
+from ..cmd.main_pre import tracer
 from ..cmd.server import PromptServer
-from ..component_model.abstract_prompt_queue import AbstractPromptQueue
+from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
 from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
 
 
-class DistributedPromptQueue(AbstractPromptQueue):
+class DistributedPromptQueue(AsyncAbstractPromptQueue):
     """
     A distributed prompt queue for the ComfyUI web client and single-threaded worker.
     """
@@ -39,12 +40,13 @@ class DistributedPromptQueue(AbstractPromptQueue):
     async def progress(self, event: SendSyncEvent, data: SendSyncData, sid: Optional[str]) -> None:
         self._caller_server.send_sync(event, data, sid=sid)
 
-    async def put_async(self, queue_item: QueueItem):
+    @tracer.start_as_current_span("Put Async")
+    async def put_async(self, queue_item: QueueItem) -> TaskInvocation | None:
         assert self._is_caller
         assert self._rpc is not None
         if self._closing:
-            return
+            return None
         self._caller_local_in_progress[queue_item.prompt_id] = queue_item
         if self._caller_server is not None:
             self._caller_server.queue_updated()
@@ -88,9 +90,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
             # if we have a completer, propoagate the exception to it
             if queue_item.completed is not None:
                 queue_item.completed.set_exception(e)
-            else:
-                # otherwise, this should raise in the event loop, which I suppose isn't handled
-                raise e
+            raise e
         finally:
             self._caller_local_in_progress.pop(queue_item.prompt_id)
             if self._caller_server is not None:
@@ -12,6 +12,7 @@ from aiormq import AMQPConnectionError
 from .distributed_progress import DistributedExecutorToClientProgress
 from .distributed_types import RpcRequest, RpcReply
 from ..client.embedded_comfy_client import EmbeddedComfyClient
+from ..cmd.main_pre import tracer
 from ..component_model.queue_types import ExecutionStatus
 
 
@@ -32,6 +33,7 @@ class DistributedPromptWorker:
         self._loop = loop or asyncio.get_event_loop()
         self._embedded_comfy_client = embedded_comfy_client
 
+    @tracer.start_as_current_span("Do Work Item")
     async def _do_work_item(self, request: dict) -> dict:
         await self.on_will_complete_work_item(request)
         try:
@@ -55,6 +57,7 @@ class DistributedPromptWorker:
         await self.on_did_complete_work_item(request_obj, reply)
         return asdict(reply)
 
+    @tracer.start_as_current_span("Initialize Prompt Worker")
     async def init(self):
         await self._exit_stack.__aenter__()
         try:
@@ -9,8 +9,11 @@ import types
 from functools import reduce
 from importlib.metadata import entry_points
 
+from opentelemetry.trace import Span, Status, StatusCode
 from pkg_resources import resource_filename
 
 from .package_typing import ExportedNodes
+from ..cmd.main_pre import tracer
 
 _comfy_nodes: ExportedNodes = ExportedNodes()
 
@@ -46,29 +49,33 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
     else:
         # Iterate through all the submodules
         for _, name, is_pkg in pkgutil.iter_modules(module.__path__):
-            full_name = module.__name__ + "." + name
-            time_before = time.perf_counter()
-            success = True
-            if full_name.endswith(".disabled"):
-                continue
-            try:
-                submodule = importlib.import_module(full_name)
-                # Recursively call the function if it's a package
-                exported_nodes.update(
-                    _import_and_enumerate_nodes_in_module(submodule, print_import_times=print_import_times,
-                                                          depth=depth - 1))
-            except KeyboardInterrupt as interrupted:
-                raise interrupted
-            except Exception as x:
-                if isinstance(x, AttributeError):
-                    potential_path_error: AttributeError = x
-                    if potential_path_error.name == '__path__':
-                        continue
-                logging.error(f"{full_name} import failed", exc_info=x)
-                success = False
-                exceptions.append(x)
-            timings.append((time.perf_counter() - time_before, full_name, success))
+            span: Span
+            with tracer.start_as_current_span("Load Node") as span:
+                full_name = module.__name__ + "." + name
+                time_before = time.perf_counter()
+                success = True
+                span.set_attribute("full_name", full_name)
+                if full_name.endswith(".disabled"):
+                    continue
+                try:
+                    submodule = importlib.import_module(full_name)
+                    # Recursively call the function if it's a package
+                    new_nodes = _import_and_enumerate_nodes_in_module(submodule, print_import_times=print_import_times, depth=depth - 1)
+                    span.set_attribute("new_nodes.length", len(new_nodes))
+                    exported_nodes.update(new_nodes)
+                except KeyboardInterrupt as interrupted:
+                    raise interrupted
+                except Exception as x:
+                    if isinstance(x, AttributeError):
+                        potential_path_error: AttributeError = x
+                        if potential_path_error.name == '__path__':
+                            continue
+                    logging.error(f"{full_name} import failed", exc_info=x)
+                    success = False
+                    exceptions.append(x)
+                    span.set_status(Status(StatusCode.ERROR))
+                    span.record_exception(x)
+                timings.append((time.perf_counter() - time_before, full_name, success))
 
     if print_import_times and len(timings) > 0 or any(not success for (_, _, success) in timings):
         for (duration, module_name, success) in sorted(timings):
@@ -81,6 +88,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
     return exported_nodes
 
 
+@tracer.start_as_current_span("Import All Nodes In Workspace")
 def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
     # now actually import the nodes, to improve control of node loading order
     from comfy_extras import nodes as comfy_extras_nodes
@@ -1,4 +1,5 @@
 import os.path
+import threading
 from contextlib import contextmanager
 
 import torch
@@ -17,7 +18,7 @@ from .component_model.executor_types import ExecutorToClientProgress
 from .component_model.queue_types import BinaryEventTypes
 
 PROGRESS_BAR_ENABLED = True
-PROGRESS_BAR_HOOK = None
+_progress_bar_hook = threading.local()
 
 
@@ -479,7 +480,7 @@ def hijack_progress(server: ExecutorToClientProgress):
         if preview_image is not None:
             server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
 
-    set_progress_bar_global_hook(hook)
+    _progress_bar_hook.hook = hook
 
 
 def set_progress_bar_enabled(enabled):
@@ -491,11 +492,6 @@ def get_progress_bar_enabled() -> bool:
     return PROGRESS_BAR_ENABLED
 
 
-def set_progress_bar_global_hook(function):
-    global PROGRESS_BAR_HOOK
-    PROGRESS_BAR_HOOK = function
-
-
 class _DisabledProgressBar:
     def __init__(self, *args, **kwargs):
         pass
@@ -509,10 +505,10 @@ class _DisabledProgressBar:
 
 class ProgressBar:
     def __init__(self, total: float):
-        global PROGRESS_BAR_HOOK
+        global _progress_bar_hook
        self.total: float = total
        self.current: float = 0.0
-        self.hook = PROGRESS_BAR_HOOK
+        self.hook = _progress_bar_hook.hook if hasattr(_progress_bar_hook, "hook") else None
 
     def update_absolute(self, value, total=None, preview=None):
         if total is not None:
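The progress-hook changes above swap a module-level global for a threading.local, so each executor thread only sees the hook installed on that thread. A standalone sketch of the idea, with hypothetical names:

    # Illustrative sketch: per-thread progress hooks via threading.local.
    import threading

    _hooks = threading.local()

    def install_hook(fn):
        _hooks.hook = fn

    def current_hook():
        # matches the hasattr-guarded read in the diff: None when this thread set nothing
        return getattr(_hooks, "hook", None)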
@@ -64,7 +64,7 @@ def comfy_background_server(args_pytest):
     if not success:
         raise Exception("Failed to start background server")
     yield
-    p.kill()
+    p.terminate()
     torch.cuda.empty_cache()
 
 
@@ -1,27 +1,26 @@
 import asyncio
 import logging
 import os
+import socket
+import subprocess
 import sys
 import time
 import uuid
-import subprocess
-import socket
 
 from concurrent.futures import ThreadPoolExecutor
 from typing import List
 
 import jwt
 import pytest
 import requests
+from testcontainers.rabbitmq import RabbitMqContainer
 
 from comfy.client.aio_client import AsyncRemoteComfyClient
 from comfy.client.embedded_comfy_client import EmbeddedComfyClient
-from comfy.distributed.server_stub import ServerStub
 from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
 from comfy.component_model.make_mutable import make_mutable
 from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
 from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
-from testcontainers.rabbitmq import RabbitMqContainer
+from comfy.distributed.server_stub import ServerStub
 
 # fixes issues with running the testcontainers rabbitmqcontainer on Windows
 os.environ["TC_HOST"] = "localhost"
@@ -55,8 +54,7 @@ async def test_basic_queue_worker() -> None:
         # this unfortunately does a bunch of initialization on the test thread
         from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
         # now submit some jobs
-        distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
-                                                   connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
+        distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
         await distributed_queue.init()
         queue_item = create_test_prompt()
         res: TaskInvocation = await distributed_queue.put_async(queue_item)
@@ -74,10 +72,8 @@ async def test_distributed_prompt_queues_same_process():
     connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
 
     from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
-    async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
-                                      connection_uri=connection_uri) as frontend:
-        async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False,
-                                          connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
+    async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
+        async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
             test_prompt = create_test_prompt()
             test_prompt.completed = asyncio.Future()
 
@@ -117,6 +113,7 @@ async def test_frontend_backend_workers():
     frontend_command = [
         "comfyui",
         "--listen=0.0.0.0",
+        "--port=9001",
         "--cpu",
         "--distributed-queue-frontend",
         f"--distributed-queue-connection-uri={connection_uri}",
@@ -125,6 +122,7 @@ async def test_frontend_backend_workers():
     processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
     backend_command = [
         "comfyui-worker",
+        "--port=9002",
         f"--distributed-queue-connection-uri={connection_uri}",
     ]
 
@@ -145,7 +143,7 @@ async def test_frontend_backend_workers():
         client = AsyncRemoteComfyClient(server_address=server_address)
         prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
         png_image_bytes = await client.queue_prompt(prompt)
-        assert len(png_image_bytes) > 1000
+        assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
     finally:
         for process in processes_to_close:
             process.terminate()
@@ -24,8 +24,7 @@ async def test_cuda_memory_usage():
     assert memory_after_workflow > starting_memory, "Expected CUDA memory to increase after running the workflow"
 
     ending_memory = torch.cuda.memory_allocated(device)
-    assert abs(
-        ending_memory - starting_memory) < 1e7, "Expected CUDA memory to return close to starting memory after cleanup"
+    assert abs(ending_memory - starting_memory) < 1e7, "Expected CUDA memory to return close to starting memory after cleanup"
 
 
 @pytest.mark.asyncio