diff --git a/comfy/distributed/distributed_progress.py b/comfy/distributed/distributed_progress.py index 4847332d8..1a122b28d 100644 --- a/comfy/distributed/distributed_progress.py +++ b/comfy/distributed/distributed_progress.py @@ -61,7 +61,12 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress): # todo: user_id should never be none here return - await self._rpc.call(_get_name(self._queue_name, user_id), {"event": event, "data": data}) + try: + # we don't need to await this coroutine + _ = asyncio.create_task(self._rpc.call(_get_name(self._queue_name, user_id), {"event": event, "data": data}, expiration=1000)) + except asyncio.TimeoutError: + # these can gracefully expire + pass def send_sync(self, event: SendSyncEvent, diff --git a/comfy/distributed/distributed_prompt_queue.py b/comfy/distributed/distributed_prompt_queue.py index 7566f20ec..d55fe7526 100644 --- a/comfy/distributed/distributed_prompt_queue.py +++ b/comfy/distributed/distributed_prompt_queue.py @@ -275,7 +275,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue): return self._connection = await connect_robust(self._connection_uri, loop=self._loop) self._channel = await self._connection.channel() - self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False) + self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=False, durable=True) if self._is_caller: self._caller_progress_handlers = ProgressHandlers(self._rpc, self._caller_server, self._queue_name) # this makes the queue available to complete work items diff --git a/comfy/language/chat_templates.py b/comfy/language/chat_templates.py index 260aad95f..78838d064 100644 --- a/comfy/language/chat_templates.py +++ b/comfy/language/chat_templates.py @@ -10,9 +10,8 @@ KNOWN_CHAT_TEMPLATES = {} def _update_known_chat_templates(): try: - _chat_templates: Traversable - with files(__package__) / "chat_templates" as _chat_templates: - _extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()} - KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates) + _chat_templates: Traversable = files(__package__) / "chat_templates" + _extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()} + KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates) except ImportError as exc: logging.warning("Could not load extra chat templates, some text models will fail", exc_info=exc) diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py index ac5f6e528..81f969f81 100644 --- a/tests/distributed/test_distributed_queue.py +++ b/tests/distributed/test_distributed_queue.py @@ -197,7 +197,7 @@ async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_w pytest.fail("Failed to get a 200 response with valid data within the timeout period") -class TestWorker(DistributedPromptWorker): +class Worker(DistributedPromptWorker): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.processed_workflows: set[str] = set() @@ -215,9 +215,9 @@ async def test_two_workers_distinct_requests(): connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" # Start two test workers - workers: list[TestWorker] = [] + workers: list[Worker] = [] for i in range(2): - worker = TestWorker(connection_uri=connection_uri, health_check_port=9090 + i, executor=ProcessPoolExecutor(max_workers=1)) + worker = Worker(connection_uri=connection_uri, health_check_port=9090 + i, executor=ProcessPoolExecutor(max_workers=1)) await worker.init() workers.append(worker) diff --git a/tests/inference/test_workflows.py b/tests/inference/test_workflows.py index 3c2df3b46..98417e70f 100644 --- a/tests/inference/test_workflows.py +++ b/tests/inference/test_workflows.py @@ -14,7 +14,6 @@ from . import workflows @pytest.fixture(scope="module", autouse=False) -@pytest.mark.asyncio async def client(tmp_path_factory) -> EmbeddedComfyClient: async with EmbeddedComfyClient() as client: yield client