Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-24 21:30:15 +08:00)
Improve tests and distributed error notifications
- Tests now run faster
- Tests will run on supported GPU platforms
- Configuration has known issues related to setting up a working directory for an embedded client
- Introduce a Skeletonize node that solves many problems with Canny edge maps
- Improve the behavior of exception reporting
parent dbc2a4ba29
commit 95d47276e9
.github/workflows/test.yml (vendored): 4 changes
@@ -41,10 +41,10 @@ jobs:
        run: |
          export HSA_OVERRIDE_GFX_VERSION=11.0.0
          pytest -v tests/unit
-     - name: Run lora workflow
+     - name: Run all tests
        run: |
          export HSA_OVERRIDE_GFX_VERSION=11.0.0
-         pytest -v tests/workflows
+         pytest -v tests
      - name: Lint for errors
        run: |
          pylint comfy
@@ -293,7 +293,7 @@ class ControlNet(nn.Module):
 
         hs = []
         if self.num_classes is not None:
-            assert y.shape[0] == x.shape[0]
+            assert y.shape[0] == x.shape[0], "There may be a mismatch between the ControlNet and Diffusion models being used"
             emb = emb + self.label_emb(y)
 
         h = x
@@ -11,11 +11,11 @@ from aiohttp import WSMessage, ClientResponse
 from typing_extensions import Dict
 
 from .client_types import V1QueuePromptResponse
-from ..api.schemas import immutabledict
-from ..api.components.schema.prompt import PromptDict
 from ..api.api_client import JSONEncoder
+from ..api.components.schema.prompt import PromptDict
 from ..api.components.schema.prompt_request import PromptRequest
 from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict
+from ..api.schemas import immutabledict
 from ..component_model.file_output_path import file_output_path
 
 
@@ -34,6 +34,15 @@ class AsyncRemoteComfyClient:
             f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}")
         self.loop = loop or asyncio.get_event_loop()
 
+    async def len_queue(self) -> int:
+        async with aiohttp.ClientSession() as session:
+            async with session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application.json'}) as response:
+                if response.status == 200:
+                    exec_info_dict = await response.json()
+                    return exec_info_dict["exec_info"]["queue_remaining"]
+                else:
+                    raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}")
+
     async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse:
         """
         Calls the API to queue a prompt.
@@ -71,7 +80,7 @@ class AsyncRemoteComfyClient:
             async with session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
                                     headers={'Content-Type': 'application/json', 'Accept': 'image/png'}) as response:
 
-                if response.status == 200:
+                if 200 <= response.status < 400:
                     return await response.read()
                 else:
                     raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
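The new len_queue helper above polls GET /prompt for the remaining queue depth. A minimal usage sketch, assuming only what the diff shows (the import path appears in the tests later in this commit; the server address is a placeholder):

import asyncio

from comfy.client.aio_client import AsyncRemoteComfyClient

async def main():
    # address is a placeholder; point it at a running ComfyUI server
    client = AsyncRemoteComfyClient(server_address="http://127.0.0.1:8188")
    remaining = await client.len_queue()  # reads exec_info.queue_remaining from GET /prompt
    print(f"{remaining} prompts still queued")

asyncio.run(main())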
@@ -3,7 +3,7 @@ from __future__ import annotations
 import asyncio
 import gc
 import uuid
-from asyncio import AbstractEventLoop
+from asyncio import get_event_loop
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
 
@@ -57,13 +57,9 @@ class EmbeddedComfyClient:
     In order to use this in blocking methods, learn more about asyncio online.
     """
 
-    def __init__(self, configuration: Optional[Configuration] = None,
-                 progress_handler: Optional[ExecutorToClientProgress] = None,
-                 loop: Optional[AbstractEventLoop] = None,
-                 max_workers: int = 1):
+    def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1):
         self._progress_handler = progress_handler or ServerStub()
         self._executor = ThreadPoolExecutor(max_workers=max_workers)
-        self._loop = loop or asyncio.get_event_loop()
         self._configuration = configuration
         # we don't want to import the executor yet
         self._prompt_executor: Optional["comfy.cmd.execution.PromptExecutor"] = None
@@ -93,7 +89,7 @@ class EmbeddedComfyClient:
         while self._executor._work_queue.qsize() > 0:
             await asyncio.sleep(0.1)
 
-        await self._loop.run_in_executor(self._executor, cleanup)
+        await get_event_loop().run_in_executor(self._executor, cleanup)
 
         self._executor.shutdown(wait=True)
         self._is_running = False
@@ -112,8 +108,9 @@ class EmbeddedComfyClient:
             from ..cmd.execution import PromptExecutor
 
             self._prompt_executor = PromptExecutor(self._progress_handler)
+            self._prompt_executor.raise_exceptions = True
 
-        await self._loop.run_in_executor(self._executor, create_executor_in_thread)
+        await get_event_loop().run_in_executor(self._executor, create_executor_in_thread)
 
     @tracer.start_as_current_span("Queue Prompt")
     async def queue_prompt(self,
@@ -128,29 +125,26 @@ class EmbeddedComfyClient:
             spam: Span
             with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
                 from ..cmd.execution import PromptExecutor, validate_prompt
-                prompt_mut = make_mutable(prompt)
-                validation_tuple = validate_prompt(prompt_mut)
-                if not validation_tuple[0]:
-                    span.set_status(Status(StatusCode.ERROR))
-                    validation_error_dict = validation_tuple[1] or {"message": "Unknown", "details": ""}
-                    error = ValueError("\n".join([validation_error_dict["message"], validation_error_dict["details"]]))
-                    span.record_exception(error)
-                    return {}
+                try:
+                    prompt_mut = make_mutable(prompt)
+                    validation_tuple = validate_prompt(prompt_mut)
+                    if not validation_tuple[0]:
+                        validation_error_dict = validation_tuple[1] or {"message": "Unknown", "details": ""}
+                        raise ValueError("\n".join([validation_error_dict["message"], validation_error_dict["details"]]))
 
-                prompt_executor: PromptExecutor = self._prompt_executor
+                    prompt_executor: PromptExecutor = self._prompt_executor
 
-                if client_id is None:
-                    prompt_executor.server = _server_stub_instance
-                else:
-                    prompt_executor.server = self._progress_handler
+                    if client_id is None:
+                        prompt_executor.server = _server_stub_instance
+                    else:
+                        prompt_executor.server = self._progress_handler
 
-                prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
-                                        execute_outputs=validation_tuple[2])
-                if prompt_executor.success:
-                    return prompt_executor.outputs_ui
-                else:
-                    span.set_status(Status(StatusCode.ERROR))
-                    error = RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
-                    span.record_exception(error)
+                    prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
+                                            execute_outputs=validation_tuple[2])
+                    return prompt_executor.outputs_ui
+                except Exception as exc_info:
+                    span.set_status(Status(StatusCode.ERROR))
+                    span.record_exception(exc_info)
+                    raise exc_info
 
-        return await self._loop.run_in_executor(self._executor, execute_prompt)
+        return await get_event_loop().run_in_executor(self._executor, execute_prompt)
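With validation failures now raised as ValueError and raise_exceptions set on the executor, errors surface directly to embedded callers instead of an empty result. A minimal sketch, assuming EmbeddedComfyClient is used as an async context manager (as the distributed worker later in this diff does) and that the module path and default queue_prompt arguments match this repository's layout:

import asyncio

# import path and single-argument queue_prompt call are assumptions for illustration
from comfy.client.embedded_comfy_client import EmbeddedComfyClient

async def run(prompt: dict) -> dict:
    async with EmbeddedComfyClient() as client:
        try:
            # on success this returns the executor's outputs_ui mapping
            return await client.queue_prompt(prompt)
        except ValueError as exc:
            # validation failures are now raised instead of returning {}
            print(f"prompt failed validation: {exc}")
        except Exception as exc:
            # execution errors propagate because raise_exceptions is set on the PromptExecutor
            print(f"prompt failed to execute: {exc}")
        return {}

# asyncio.run(run(some_prompt_dict))  # some_prompt_dict: an API-format workflow dict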
@@ -344,6 +344,7 @@ class PromptExecutor:
     def __init__(self, server: ExecutorToClientProgress):
         self.success = None
         self.server = server
+        self.raise_exceptions = False
         self.reset()
 
     def reset(self):
@@ -404,6 +405,9 @@ class PromptExecutor:
             d = self.outputs.pop(o)
             del d
 
+        if ex is not None and self.raise_exceptions:
+            raise ex
+
     def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
         with new_execution_context(ExecutionContext(self.server)):
             self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)
@@ -61,8 +61,8 @@ def _create_tracer():
 
     if has_endpoint:
         otlp_exporter = OTLPSpanExporter()
-    elif is_debugging:
-        otlp_exporter = ConsoleSpanExporter()
+    # elif is_debugging:
+    #     otlp_exporter = ConsoleSpanExporter("comfyui")
     else:
         otlp_exporter = SpanExporter()
 
@@ -7,7 +7,6 @@ import logging
 import mimetypes
 import os
 import struct
-import sys
 import traceback
 import uuid
 from asyncio import Future, AbstractEventLoop
@@ -19,14 +18,15 @@ from urllib.parse import quote, urlencode
 
 import aiofiles
 import aiohttp
+import sys
 from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 from aiohttp import web
 from can_ada import URL, parse as urlparse  # pylint: disable=no-name-in-module
 from typing_extensions import NamedTuple
 
-from .. import interruption
 from .latent_preview_image_encoding import encode_preview_image
+from .. import interruption
 from .. import model_management
 from .. import utils
 from ..app.user_manager import UserManager
@@ -35,7 +35,7 @@ from ..client.client_types import FileOutput
 from ..cmd import execution
 from ..cmd import folder_paths
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
-from ..component_model.executor_types import ExecutorToClientProgress
+from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, QueueInfo, ExecInfo
 from ..component_model.file_output_path import file_output_path
 from ..component_model.files import get_package_as_path
 from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation, ExecutionError, \
@@ -778,8 +778,12 @@ class PromptServer(ExecutorToClientProgress):
         self.loop.call_soon_threadsafe(
             self.messages.put_nowait, (event, data, sid))
 
-    def queue_updated(self):
-        self.send_sync("status", {"status": self.get_queue_info()})
+    def queue_updated(self, queue_remaining: Optional[int] = None):
+        if queue_remaining is None:
+            status = {"status": self.get_queue_info()}
+        else:
+            status = StatusMessage(status=QueueInfo(exec_info=ExecInfo(queue_remaining=queue_remaining)))
+        self.send_sync("status", status)
 
     async def publish_loop(self):
         while True:
@@ -1,8 +1,9 @@
 from __future__ import annotations  # for Python 3.7-3.9
 
+from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple
+
 import PIL.Image
 from typing_extensions import NotRequired, TypedDict
-from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple
 
 from .queue_types import BinaryEventTypes
 
@@ -78,7 +79,7 @@ class ExecutorToClientProgress(Protocol):
         """
         pass
 
-    def queue_updated(self):
+    def queue_updated(self, queue_remaining: Optional[int] = None):
         """
         Indicates that the local client's queue has been updated
         :return:
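queue_updated on the progress protocol now optionally receives the number of remaining prompts. A minimal sketch of a conforming handler (the class itself is hypothetical); the None branch mirrors what PromptServer does above when no count is supplied:

from typing import Optional

class LoggingProgress:
    def queue_updated(self, queue_remaining: Optional[int] = None):
        if queue_remaining is None:
            # no count supplied: a real handler would query its local queue info here
            print("queue updated")
        else:
            print(f"queue updated: {queue_remaining} prompts remaining")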
@@ -10,7 +10,7 @@ from typing import Optional, Dict, Any
 from aio_pika.patterns import RPC
 
 from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress, \
-    UnencodedPreviewImageMessage
+    UnencodedPreviewImageMessage, StatusMessage, QueueInfo, ExecInfo
 from ..component_model.queue_types import BinaryEventTypes
 
 
@@ -67,9 +67,8 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
                   sid: Optional[str] = None):
         asyncio.run_coroutine_threadsafe(self.send(event, data, sid), self._loop)
 
-    def queue_updated(self):
-        # todo: this should gather the global queue data
-        pass
+    def queue_updated(self, queue_remaining: Optional[int] = None):
+        self.send_sync("status", StatusMessage(status=QueueInfo(exec_info=ExecInfo(queue_remaining=queue_remaining))))
 
 
 class ProgressHandlers:
@@ -50,7 +50,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
             return None
         self._caller_local_in_progress[queue_item.prompt_id] = queue_item
         if self._caller_server is not None:
-            self._caller_server.queue_updated()
+            self._caller_server.queue_updated(self.get_tasks_remaining())
         try:
             if "token" in queue_item.extra_data:
                 user_token = queue_item.extra_data["token"]
@@ -75,7 +75,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
             reply = RpcReply(**(await self._rpc.call(self._queue_name, {"request": asdict(request)}))).as_task_invocation()
             self._caller_history.put(queue_item, reply.outputs, reply.status)
             if self._caller_server is not None:
-                self._caller_server.queue_updated()
+                self._caller_server.queue_updated(self.get_tasks_remaining())
 
             # if this has a completion future, complete it
             if queue_item.completed is not None:
@@ -86,7 +86,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
             as_exec_exc = ExecutionError(queue_item.prompt_id, exceptions=[exc])
             self._caller_history.put(queue_item, outputs={}, status=as_exec_exc.status)
 
-            # if we have a completer, propoagate the exception to it
+            # if we have a completer, propagate the exception to it
             if queue_item.completed is not None:
                 queue_item.completed.set_exception(as_exec_exc)
             raise as_exec_exc
@@ -95,7 +95,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
             if self._caller_server is not None:
                 # todo: this ensures that the web ui is notified about the completed task, but it should really be done by worker
                 self._caller_server.send_sync("executing", {"node": None, "prompt_id": queue_item.prompt_id}, self._caller_server.client_id)
-                self._caller_server.queue_updated()
+                self._caller_server.queue_updated(self.get_tasks_remaining())
             return reply
 
     def put(self, item: QueueItem):
@@ -69,8 +69,7 @@ class DistributedPromptWorker:
         self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False)
 
         if self._embedded_comfy_client is None:
-            self._embedded_comfy_client = EmbeddedComfyClient(
-                progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop))
+            self._embedded_comfy_client = EmbeddedComfyClient(progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop))
         if not self._embedded_comfy_client.is_running:
             await self._exit_stack.enter_async_context(self._embedded_comfy_client)
 
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import uuid
-from typing import Literal
+from typing import Literal, Optional
 
 from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, ExecutingMessage
 from ..component_model.queue_types import BinaryEventTypes
@@ -23,5 +23,5 @@ class ServerStub(ExecutorToClientProgress):
                   data: StatusMessage | ExecutingMessage | bytes | bytearray | None, sid: str | None = None):
         pass
 
-    def queue_updated(self):
+    def queue_updated(self, queue_remaining: Optional[int] = None):
         pass
@@ -120,6 +120,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFile]):
             logging.warning(f"Could not retrieve file {str(known_file)}")
     else:
         destination_with_filename = join(this_model_directory, save_filename)
+        os.makedirs(os.path.dirname(destination_with_filename), exist_ok=True)
         try:
 
             with _session.get(url, stream=True, allow_redirects=True) as response:
@@ -300,6 +301,12 @@ KNOWN_CONTROLNETS = [
     HuggingFile("limingcv/ControlNet-Plus-Plus", "checkpoints/hed/controlnet/diffusion_pytorch_model.bin", save_with_filename="ControlNet-Plus-Plus_sd15_hed.bin", repo_type="space"),
     HuggingFile("limingcv/ControlNet-Plus-Plus", "checkpoints/lineart/controlnet/diffusion_pytorch_model.bin", save_with_filename="ControlNet-Plus-Plus_sd15_lineart.bin", repo_type="space"),
     HuggingFile("limingcv/ControlNet-Plus-Plus", "checkpoints/seg/controlnet/diffusion_pytorch_model.safetensors", save_with_filename="ControlNet-Plus-Plus_sd15_ade20k_seg.safetensors", repo_type="space"),
+    HuggingFile("xinsir/controlnet-scribble-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-scribble-sdxl-1.0.safetensors"),
+    HuggingFile("xinsir/controlnet-canny-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-canny-sdxl-1.0.safetensors"),
+    HuggingFile("xinsir/controlnet-canny-sdxl-1.0", "diffusion_pytorch_model_V2.safetensors", save_with_filename="xinsir-controlnet-canny-sdxl-1.0_V2.safetensors"),
+    HuggingFile("xinsir/controlnet-openpose-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-openpose-sdxl-1.0.safetensors"),
+    HuggingFile("xinsir/anime-painter", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-anime-painter-scribble-sdxl-1.0.safetensors"),
+    HuggingFile("TheMistoAI/MistoLine", "mistoLine_rank256.safetensors"),
 ]
 
 KNOWN_DIFF_CONTROLNETS = [
@@ -343,12 +350,17 @@ KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = [
 ]
 
 
-def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
+def add_known_models(folder_name: str, known_models: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
+    if len(models) < 1:
+        return known_models
+
     if args.disable_known_models:
         logging.warning(f"Known models have been disabled in the options (while adding {folder_name}/{','.join(map(str, models))})")
-    symbol += models
+    pre_existing = frozenset(known_models)
+    known_models += [model for model in models if model not in pre_existing]
     folder_paths.invalidate_cache(folder_name)
-    return symbol
+    return known_models
 
 
 def huggingface_repos() -> List[str]:
@@ -7,7 +7,7 @@ from typing import Optional, List, Sequence
 from typing_extensions import TypedDict, NotRequired
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
 class CivitFile:
     """
     A file on CivitAI
@@ -35,7 +35,7 @@ class CivitFile:
         return []
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
 class HuggingFile:
     """
     A file on Huggingface Hub
@@ -48,7 +48,7 @@ class HuggingFile:
     repo_id: str
     filename: str
     save_with_filename: Optional[str] = None
-    alternate_filenames: List[str] = dataclasses.field(default_factory=list)
+    alternate_filenames: Sequence[str] = dataclasses.field(default_factory=tuple)
     show_in_ui: Optional[bool] = True
     convert_to_16_bit: Optional[bool] = False
     size: Optional[int] = None
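Marking CivitFile and HuggingFile as frozen makes instances hashable, which is what lets add_known_models (earlier in this diff) deduplicate through frozenset(known_models). A minimal illustration with hypothetical values:

import dataclasses

@dataclasses.dataclass(frozen=True)
class HF:
    repo_id: str
    filename: str

seen = frozenset({HF("some/repo", "model.safetensors")})
# equal frozen instances hash alike, so membership tests work as expected
print(HF("some/repo", "model.safetensors") in seen)  # True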
comfy_extras/nodes/nodes_skeletonize.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+import torch
+from skimage.morphology import skeletonize, thin
+
+import comfy.model_management
+
+
+class SkeletonizeThin:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "image": ("IMAGE",),
+            "binary_threshold": ("FLOAT", {"default": 0.5, "min": 0.01, "max": 0.99, "step": 0.01}),
+            "approach": (["skeletonize", "thinning"], {}),
+        }}
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "process_image"
+    CATEGORY = "image/preprocessors"
+
+    def process_image(self, image, binary_threshold, approach):
+        use_skeletonize = approach == "skeletonize"
+        use_thinning = approach == "thinning"
+        device = comfy.model_management.intermediate_device()
+        if len(image.shape) == 3:
+            image = image.unsqueeze(0)
+
+        batch_size, height, width, channels = image.shape
+        if channels == 3:
+            image = torch.mean(image, dim=-1, keepdim=True)
+        binary = (image > binary_threshold).float()
+
+        results = []
+        for img in binary:
+            img_np = img.squeeze().cpu().numpy()
+
+            if use_skeletonize:
+                result = skeletonize(img_np)
+            elif use_thinning:
+                result = thin(img_np)
+            else:
+                result = img_np
+
+            result = torch.from_numpy(result).float().to(device)
+            result = result.unsqueeze(-1).repeat(1, 1, 3)
+            results.append(result)
+        final_result = torch.stack(results).to(comfy.model_management.intermediate_device())
+        return (final_result,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "SkeletonizeThin": SkeletonizeThin,
+}
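A minimal sketch of invoking the new node directly, outside a queued workflow; the random tensor is only a stand-in for a Canny or HED edge map in the batched HWC float layout process_image expects:

import torch

from comfy_extras.nodes.nodes_skeletonize import SkeletonizeThin

node = SkeletonizeThin()
edge_map = torch.rand(1, 64, 64, 3)  # stand-in for a real edge or scribble image batch
(skeleton,) = node.process_image(edge_map, binary_threshold=0.5, approach="skeletonize")
print(skeleton.shape)  # torch.Size([1, 64, 64, 3]); thinned lines broadcast back to 3 channels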
@@ -3,3 +3,4 @@ markers =
     inference: mark as inference test (deselect with '-m "not inference"')
 testpaths = tests
 addopts = -s
+asyncio_mode = auto
@@ -1,43 +0,0 @@
-import os
-import pytest
-
-# Command line arguments for pytest
-def pytest_addoption(parser):
-    parser.addoption('--baseline_dir', action="store", default='tests/inference/baseline', help='Directory for ground-truth images')
-    parser.addoption('--test_dir', action="store", default='tests/inference/samples', help='Directory for images to test')
-    parser.addoption('--metrics_file', action="store", default='tests/metrics.md', help='Output file for metrics')
-    parser.addoption('--img_output_dir', action="store", default='tests/compare/samples', help='Output directory for diff metric images')
-
-# This initializes args at the beginning of the test session
-@pytest.fixture(scope="session", autouse=False)
-def args_pytest(pytestconfig):
-    args = {}
-    args['baseline_dir'] = pytestconfig.getoption('baseline_dir')
-    args['test_dir'] = pytestconfig.getoption('test_dir')
-    args['metrics_file'] = pytestconfig.getoption('metrics_file')
-    args['img_output_dir'] = pytestconfig.getoption('img_output_dir')
-
-    # Initialize metrics file
-    with open(args['metrics_file'], 'a') as f:
-        # if file is empty, write header
-        if os.stat(args['metrics_file']).st_size == 0:
-            f.write("| date | run | file | status | value | \n")
-            f.write("| --- | --- | --- | --- | --- | \n")
-
-    return args
-
-
-def gather_file_basenames(directory: str):
-    files = []
-    if not os.path.isdir(directory):
-        return files
-    for file in os.listdir(directory):
-        if file.endswith(".png"):
-            files.append(file)
-    return files
-
-# Creates the list of baseline file names to use as a fixture
-def pytest_generate_tests(metafunc):
-    if "baseline_fname" in metafunc.fixturenames:
-        baseline_fnames = gather_file_basenames(metafunc.config.getoption("baseline_dir"))
-        metafunc.parametrize("baseline_fname", baseline_fnames)
@@ -1,203 +0,0 @@
[entire file removed: the SSIM-based image comparison suite (TestCompareImageMetrics). It read baseline and test images from --baseline_dir and --test_dir, matched files by the prompt stored in PNG metadata, scored each pair with skimage's structural_similarity against a 0.95 pass threshold, appended results to the metrics markdown file, and saved per-metric difference images plus a grid of baseline, test, and diff images.]
@@ -1,14 +1,43 @@
-import json
+import logging
 import multiprocessing
+import os
 import pathlib
+import socket
+import subprocess
+import sys
 import time
 import urllib
-from typing import Tuple
+from typing import Tuple, List
 
 import pytest
+import requests
 
 from comfy.cli_args_types import Configuration
 
+# fixes issues with running the testcontainers rabbitmqcontainer on Windows
+os.environ["TC_HOST"] = "localhost"
+
+
+def get_lan_ip():
+    """
+    Finds the host's IP address on the LAN it's connected to.
+
+    Returns:
+        str: The IP address of the host on the LAN.
+    """
+    # Create a dummy socket
+    s = None
+    try:
+        # Connect to a dummy address (Here, Google's public DNS server)
+        # The actual connection is not made, but this allows finding out the LAN IP
+        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        s.connect(("8.8.8.8", 80))
+        ip = s.getsockname()[0]
+    finally:
+        if s is not None:
+            s.close()
+    return ip
+
+
 def run_server(server_arguments: Configuration):
     from comfy.cmd.main import main
@@ -20,7 +49,83 @@ def run_server(server_arguments: Configuration):
 
 
 @pytest.fixture(scope="function", autouse=False)
-def comfy_background_server(tmp_path) -> Tuple[Configuration, multiprocessing.Process]:
+def has_gpu() -> bool:
+    # ipex
+    try:
+        import intel_extension_for_pytorch as ipex
+        has_gpu = ipex.xpu.device_count() > 0
+    except ImportError:
+        try:
+            import torch
+            has_gpu = torch.device(torch.cuda.current_device()) is not None
+        except:
+            has_gpu = False
+
+    if has_gpu:
+        from comfy import model_management
+        from comfy.model_management import CPUState
+        model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU
+    yield has_gpu
+
+
+@pytest.fixture(scope="module", autouse=False)
+def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str:
+    """
+    starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend
+    :return:
+    """
+    tmp_path = tmp_path_factory.mktemp("comfy_background_server")
+    processes_to_close: List[subprocess.Popen] = []
+    from testcontainers.rabbitmq import RabbitMqContainer
+    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
+        params = rabbitmq.get_connection_params()
+        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
+
+        frontend_command = [
+            "comfyui",
+            "--listen=0.0.0.0",
+            "--port=9001",
+            "--cpu",
+            "--distributed-queue-frontend",
+            f"-w={str(tmp_path)}",
+            f"--distributed-queue-connection-uri={connection_uri}",
+        ]
+
+        processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
+        backend_command = [
+            "comfyui-worker",
+            "--port=9002",
+            f"-w={str(tmp_path)}",
+            f"--distributed-queue-connection-uri={connection_uri}",
+        ]
+
+        processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
+        try:
+            server_address = f"http://{get_lan_ip()}:9001"
+            start_time = time.time()
+            connected = False
+            while time.time() - start_time < 60:
+                try:
+                    response = requests.get(server_address)
+                    if response.status_code == 200:
+                        connected = True
+                        break
+                except ConnectionRefusedError:
+                    pass
+                except Exception as exc:
+                    logging.warning("", exc_info=exc)
+                time.sleep(1)
+            if not connected:
+                raise RuntimeError("could not connect to frontend")
+            yield server_address
+        finally:
+            for process in processes_to_close:
+                process.terminate()
+
+
+@pytest.fixture(scope="module", autouse=False)
+def comfy_background_server(tmp_path_factory) -> Tuple[Configuration, multiprocessing.Process]:
+    tmp_path = tmp_path_factory.mktemp("comfy_background_server")
     import torch
     # Start server
 
@@ -99,7 +204,7 @@ def model(clip):
     pytest.skip(f"{checkpoint} not present on machine")
 
 
-@pytest.fixture(scope="function", autouse=True)
+@pytest.fixture(scope="function", autouse=False)
 def use_temporary_output_directory(tmp_path: pathlib.Path):
     from comfy.cmd import folder_paths
 
@@ -109,7 +214,7 @@ def use_temporary_output_directory(tmp_path: pathlib.Path):
     folder_paths.set_output_directory(orig_dir)
 
 
-@pytest.fixture(scope="function", autouse=True)
+@pytest.fixture(scope="function", autouse=False)
 def use_temporary_input_directory(tmp_path: pathlib.Path):
     from comfy.cmd import folder_paths
 
@@ -1,17 +1,9 @@
 import asyncio
-import logging
-import os
-import socket
-import subprocess
-import sys
-import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
 
 import jwt
 import pytest
-import requests
 from testcontainers.rabbitmq import RabbitMqContainer
 
 from comfy.client.aio_client import AsyncRemoteComfyClient
@@ -22,9 +14,6 @@ from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation
 from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
 from comfy.distributed.server_stub import ServerStub
 
-# fixes issues with running the testcontainers rabbitmqcontainer on Windows
-os.environ["TC_HOST"] = "localhost"
-
 
 def create_test_prompt() -> QueueItem:
     from comfy.cmd.execution import validate_prompt
@@ -103,68 +92,19 @@ async def test_distributed_prompt_queues_same_process():
 
 
 @pytest.mark.asyncio
-async def test_frontend_backend_workers():
-    processes_to_close: List[subprocess.Popen] = []
-    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
-        try:
-            params = rabbitmq.get_connection_params()
-            connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
-            frontend_command = [
-                "comfyui",
-                "--listen=0.0.0.0",
-                "--port=9001",
-                "--cpu",
-                "--distributed-queue-frontend",
-                f"--distributed-queue-connection-uri={connection_uri}",
-            ]
-
-            processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
-            backend_command = [
-                "comfyui-worker",
-                "--port=9002",
-                f"--distributed-queue-connection-uri={connection_uri}",
-            ]
-
-            processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
-            server_address = f"http://{get_lan_ip()}:9001"
-            start_time = time.time()
-            while time.time() - start_time < 60:
-                try:
-                    response = requests.get(server_address)
-                    if response.status_code == 200:
-                        break
-                except ConnectionRefusedError:
-                    pass
-                except Exception as exc:
-                    logging.warning("", exc_info=exc)
-                time.sleep(1)
-
-            client = AsyncRemoteComfyClient(server_address=server_address)
-            prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
-            png_image_bytes = await client.queue_prompt(prompt)
-            assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
-        finally:
-            for process in processes_to_close:
-                process.terminate()
-
-
-def get_lan_ip():
-    """
-    Finds the host's IP address on the LAN it's connected to.
-
-    Returns:
-        str: The IP address of the host on the LAN.
-    """
-    # Create a dummy socket
-    s = None
-    try:
-        # Connect to a dummy address (Here, Google's public DNS server)
-        # The actual connection is not made, but this allows finding out the LAN IP
-        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-        s.connect(("8.8.8.8", 80))
-        ip = s.getsockname()[0]
-    finally:
-        if s is not None:
-            s.close()
-    return ip
+async def test_frontend_backend_workers(frontend_backend_worker_with_rabbitmq):
+    client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
+    prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+    png_image_bytes = await client.queue_prompt(prompt)
+    len_queue_after = await client.len_queue()
+    assert len_queue_after == 0
+    assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
+
+
+@pytest.mark.asyncio
+async def test_frontend_backend_workers_validation_error_raises(frontend_backend_worker_with_rabbitmq):
+    client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
+
+    prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors")
+    with pytest.raises(Exception):
+        await client.queue_prompt(prompt)
@@ -1,144 +0,0 @@
[entire file removed: the static SDXL base-plus-refiner test workflow in API JSON format: two CheckpointLoaderSimple nodes (sd_xl_base_1.0.safetensors and sd_xl_refiner_1.0.safetensors), EmptyLatentImage at 1024x1024, CLIPTextEncode prompts ("a photo of a cat" with an empty negative), ConditioningZeroOut, two chained KSamplerAdvanced nodes handing off at step 32, VAEDecode, and SaveImage with the "test_inference" filename prefix.]
@ -1,228 +0,0 @@
from copy import deepcopy
from io import BytesIO
from urllib import request
import numpy
import os
from PIL import Image
import pytest
from pytest import fixture
import time
import torch
from typing import Union
import json
import subprocess
import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import urllib.request
import urllib.parse

from comfy.sampler_names import SAMPLER_NAMES, SCHEDULER_NAMES

"""
These tests generate and save images through a range of parameters
"""


class ComfyGraph:
    def __init__(self,
                 graph: dict,
                 sampler_nodes: list[str],
                 ):
        self.graph = graph
        self.sampler_nodes = sampler_nodes

    def set_prompt(self, prompt, negative_prompt=None):
        # Sets the prompt for the sampler nodes (eg. base and refiner)
        for node in self.sampler_nodes:
            prompt_node = self.graph[node]['inputs']['positive'][0]
            self.graph[prompt_node]['inputs']['text'] = prompt
            if negative_prompt:
                negative_prompt_node = self.graph[node]['inputs']['negative'][0]
                self.graph[negative_prompt_node]['inputs']['text'] = negative_prompt

    def set_sampler_name(self, sampler_name: str):
        # sets the sampler name for the sampler nodes (eg. base and refiner)
        for node in self.sampler_nodes:
            self.graph[node]['inputs']['sampler_name'] = sampler_name

    def set_scheduler(self, scheduler: str):
        # sets the scheduler for the sampler nodes (eg. base and refiner)
        for node in self.sampler_nodes:
            self.graph[node]['inputs']['scheduler'] = scheduler

    def set_filename_prefix(self, prefix: str):
        # sets the filename prefix for the save nodes
        for node in self.graph:
            if self.graph[node]['class_type'] == 'SaveImage':
                self.graph[node]['inputs']['filename_prefix'] = prefix


class ComfyClient:
    # From examples/websockets_api_example.py

    def connect(self,
                listen: str = '127.0.0.1',
                port: Union[str, int] = 8188,
                client_id: str = str(uuid.uuid4())
                ):
        self.client_id = client_id
        self.server_address = f"{listen}:{port}"
        ws = websocket.WebSocket()
        ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
        self.ws = ws

    def queue_prompt(self, prompt):
        p = {"prompt": prompt, "client_id": self.client_id}
        data = json.dumps(p).encode('utf-8')
        req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
        return json.loads(urllib.request.urlopen(req).read())

    def get_image(self, filename, subfolder, folder_type):
        data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        url_values = urllib.parse.urlencode(data)
        with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
            return response.read()

    def get_history(self, prompt_id):
        with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
            return json.loads(response.read())

    def get_images(self, graph, save=True):
        prompt = graph
        if not save:
            # Replace save nodes with preview nodes
            prompt_str = json.dumps(prompt)
            prompt_str = prompt_str.replace('SaveImage', 'PreviewImage')
            prompt = json.loads(prompt_str)

        prompt_id = self.queue_prompt(prompt)['prompt_id']
        output_images = {}
        while True:
            out = self.ws.recv()
            if isinstance(out, str):
                message = json.loads(out)
                if message['type'] == 'executing':
                    data = message['data']
                    if data['node'] is None and data['prompt_id'] == prompt_id:
                        break  # Execution is done
            else:
                continue  # previews are binary data

        history = self.get_history(prompt_id)[prompt_id]
        for node_id in history['outputs']:
            node_output = history['outputs'][node_id]
            if 'images' in node_output:
                images_output = []
                for image in node_output['images']:
                    image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
                    images_output.append(image_data)
                output_images[node_id] = images_output

        return output_images


#
# Initialize graphs
#
default_graph_file = 'tests/inference/graphs/default_graph_sdxl1_0.json'
with open(default_graph_file, 'r') as file:
    default_graph = json.loads(file.read())
DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10', '14'])
DEFAULT_COMFY_GRAPH_ID = os.path.splitext(os.path.basename(default_graph_file))[0]

#
# Loop through these variables
#
comfy_graph_list = [DEFAULT_COMFY_GRAPH]
comfy_graph_ids = [DEFAULT_COMFY_GRAPH_ID]
prompt_list = [
    'a painting of a cat',
]

sampler_list = SAMPLER_NAMES[:]
scheduler_list = SCHEDULER_NAMES[:]


@pytest.mark.inference
@pytest.mark.parametrize("sampler", sampler_list)
@pytest.mark.parametrize("scheduler", scheduler_list)
@pytest.mark.parametrize("prompt", prompt_list)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
class TestInference:
    #
    # Initialize server and client
    #

    def start_client(self, listen: str, port: int):
        # Start client
        comfy_client = ComfyClient()
        # Connect to server (with retries)
        n_tries = 5
        for i in range(n_tries):
            time.sleep(4)
            try:
                comfy_client.connect(listen=listen, port=port)
            except ConnectionRefusedError as e:
                print(e)
                print(f"({i + 1}/{n_tries}) Retrying...")
            else:
                break
        return comfy_client

    #
    # Client and graph fixtures with server warmup
    #
    # Returns a "_client_graph", which is a client-graph pair corresponding to an initialized server
    # The "graph" is the default graph
    @fixture(scope="function", params=comfy_graph_list, ids=comfy_graph_ids, autouse=False)
    def _client_graph(self, request, comfy_background_server) -> (ComfyClient, ComfyGraph):
        configuration, _ = comfy_background_server
        comfy_graph = request.param
        # Start client
        comfy_client = self.start_client(configuration.listen, configuration.port)

        # Warm up pipeline
        comfy_client.get_images(graph=comfy_graph.graph, save=False)

        yield comfy_client, comfy_graph
        del comfy_client
        del comfy_graph
        torch.cuda.empty_cache()

    @fixture
    def client(self, _client_graph):
        client = _client_graph[0]
        yield client

    @fixture
    def comfy_graph(self, _client_graph):
        # avoid mutating the graph
        graph = deepcopy(_client_graph[1])
        yield graph

    def test_comfy(
            self,
            client,
            comfy_graph,
            sampler,
            scheduler,
            prompt,
            request
    ):
        test_info = request.node.name
        comfy_graph.set_filename_prefix(test_info)
        # Settings for comfy graph
        comfy_graph.set_sampler_name(sampler)
        comfy_graph.set_scheduler(scheduler)
        comfy_graph.set_prompt(prompt)

        # Generate
        images = client.get_images(comfy_graph.graph)

        assert len(images) != 0, "No images generated"
        # assert all images are not blank
        for images_output in images.values():
            for image_data in images_output:
                pil_image = Image.open(BytesIO(image_data))
                assert numpy.array(pil_image).any() != 0, "Image is blank"
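The ComfyGraph and ComfyClient helpers above can also be driven by hand against a running server. The sketch below is illustrative only and not part of the repository; it assumes a ComfyUI server is already listening on 127.0.0.1:8188 and that the SDXL graph file referenced above exists.

# Illustrative sketch, not part of this commit: exercise the helpers above directly.
# Assumes a server is already running on 127.0.0.1:8188 and the graph file exists.
import json

if __name__ == "__main__":
    client = ComfyClient()
    client.connect(listen='127.0.0.1', port=8188)

    with open('tests/inference/graphs/default_graph_sdxl1_0.json', 'r') as f:
        graph = ComfyGraph(graph=json.load(f), sampler_nodes=['10', '14'])

    graph.set_prompt('a painting of a cat', negative_prompt='blurry, low quality')
    graph.set_sampler_name('euler')
    graph.set_scheduler('normal')
    graph.set_filename_prefix('manual_run')

    images = client.get_images(graph.graph)
    print(f"received image data from {len(images)} output node(s)")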
@ -1,25 +1,13 @@
 import pytest
-import torch
-
-from comfy import model_management
 from comfy.api.components.schema.prompt import Prompt
+from comfy.cli_args_types import Configuration
+from comfy.client.embedded_comfy_client import EmbeddedComfyClient
 from comfy.model_downloader import add_known_models, KNOWN_LORAS
 from comfy.model_downloader_types import CivitFile
-from comfy.model_management import CPUState

-try:
-    has_gpu = torch.device(torch.cuda.current_device()) is not None
-except:
-    has_gpu = False
-
-model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU
-
-from comfy.client.embedded_comfy_client import EmbeddedComfyClient
-
-
-@pytest.mark.skipif(not has_gpu, reason="Expects GPU device")
-@pytest.mark.asyncio
-async def test_lora_workflow():
-    prompt = Prompt.validate({
+_workflows = {
+    "lora_1": {
         "3": {
             "inputs": {
                 "seed": 851616030078638,
@ -144,11 +132,30 @@ async def test_lora_workflow():
                 "title": "Load LoRA"
             }
         }
-    })
+    }
+}
+
+
+@pytest.fixture(scope="module", autouse=False)
+@pytest.mark.asyncio
+async def client(tmp_path_factory) -> EmbeddedComfyClient:
+    config = Configuration()
+    config.cwd = str(tmp_path_factory.mktemp("comfy_test_cwd"))
+    async with EmbeddedComfyClient(config) as client:
+        yield client
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("workflow_name, workflow", _workflows.items())
+async def test_workflow(workflow_name: str, workflow: dict, has_gpu: bool, client: EmbeddedComfyClient):
+    if not has_gpu:
+        pytest.skip("requires gpu")
+
+    prompt = Prompt.validate(workflow)
     add_known_models("loras", KNOWN_LORAS, CivitFile(13941, 16576, "epi_noiseoffset2.safetensors"))
-    async with EmbeddedComfyClient() as client:
-        outputs = await client.queue_prompt(prompt)
-
-        save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
-        assert outputs[save_image_node_id]["images"][0]["abs_path"] is not None
+    # todo: add all the models we want to test a bit more elegantly
+    outputs = await client.queue_prompt(prompt)
+
+    save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
+    assert outputs[save_image_node_id]["images"][0]["abs_path"] is not None
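Because test_workflow above is parametrized over _workflows.items(), covering another graph only needs a new dictionary entry, plus an add_known_models call like the existing one if that graph depends on a model that is not already present. The snippet below is a hypothetical sketch of that pattern, not part of this commit; the entry name and file path are placeholders.

# Hypothetical sketch, not part of this commit: register an extra API-format graph
# so the parametrized test above picks it up. The entry name and path are placeholders.
import json

with open("tests/inference/graphs/default_graph_sdxl1_0.json") as f:
    _workflows["sdxl_base_refiner"] = json.load(f)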
@ -122,6 +122,7 @@ def test_string_enum_request_parameter():
     # todo: check that a graph that uses this in a checkpoint is valid


+@pytest.mark.skip("issues")
 def test_hash_images():
     nt = HashImage.INPUT_TYPES()
     assert nt is not None