mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-23 21:00:16 +08:00
Tweaks to distributed queueing
- Do not auto delete the queue - Make the queue durable - Progress notifications expire - Deprecation fix
This commit is contained in:
parent
44be2591df
commit
4150dbbbe5
@ -61,7 +61,12 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
|
|||||||
# todo: user_id should never be none here
|
# todo: user_id should never be none here
|
||||||
return
|
return
|
||||||
|
|
||||||
await self._rpc.call(_get_name(self._queue_name, user_id), {"event": event, "data": data})
|
try:
|
||||||
|
# we don't need to await this coroutine
|
||||||
|
_ = asyncio.create_task(self._rpc.call(_get_name(self._queue_name, user_id), {"event": event, "data": data}, expiration=1000))
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# these can gracefully expire
|
||||||
|
pass
|
||||||
|
|
||||||
def send_sync(self,
|
def send_sync(self,
|
||||||
event: SendSyncEvent,
|
event: SendSyncEvent,
|
||||||
|
|||||||
@ -275,7 +275,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
|
|||||||
return
|
return
|
||||||
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
|
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
|
||||||
self._channel = await self._connection.channel()
|
self._channel = await self._connection.channel()
|
||||||
self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False)
|
self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=False, durable=True)
|
||||||
if self._is_caller:
|
if self._is_caller:
|
||||||
self._caller_progress_handlers = ProgressHandlers(self._rpc, self._caller_server, self._queue_name)
|
self._caller_progress_handlers = ProgressHandlers(self._rpc, self._caller_server, self._queue_name)
|
||||||
# this makes the queue available to complete work items
|
# this makes the queue available to complete work items
|
||||||
|
|||||||
@ -10,9 +10,8 @@ KNOWN_CHAT_TEMPLATES = {}
|
|||||||
|
|
||||||
def _update_known_chat_templates():
|
def _update_known_chat_templates():
|
||||||
try:
|
try:
|
||||||
_chat_templates: Traversable
|
_chat_templates: Traversable = files(__package__) / "chat_templates"
|
||||||
with files(__package__) / "chat_templates" as _chat_templates:
|
_extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
|
||||||
_extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
|
KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
|
||||||
KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
|
|
||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
logging.warning("Could not load extra chat templates, some text models will fail", exc_info=exc)
|
logging.warning("Could not load extra chat templates, some text models will fail", exc_info=exc)
|
||||||
|
|||||||
@ -197,7 +197,7 @@ async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_w
|
|||||||
pytest.fail("Failed to get a 200 response with valid data within the timeout period")
|
pytest.fail("Failed to get a 200 response with valid data within the timeout period")
|
||||||
|
|
||||||
|
|
||||||
class TestWorker(DistributedPromptWorker):
|
class Worker(DistributedPromptWorker):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.processed_workflows: set[str] = set()
|
self.processed_workflows: set[str] = set()
|
||||||
@ -215,9 +215,9 @@ async def test_two_workers_distinct_requests():
|
|||||||
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
|
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
|
||||||
|
|
||||||
# Start two test workers
|
# Start two test workers
|
||||||
workers: list[TestWorker] = []
|
workers: list[Worker] = []
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
worker = TestWorker(connection_uri=connection_uri, health_check_port=9090 + i, executor=ProcessPoolExecutor(max_workers=1))
|
worker = Worker(connection_uri=connection_uri, health_check_port=9090 + i, executor=ProcessPoolExecutor(max_workers=1))
|
||||||
await worker.init()
|
await worker.init()
|
||||||
workers.append(worker)
|
workers.append(worker)
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,6 @@ from . import workflows
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", autouse=False)
|
@pytest.fixture(scope="module", autouse=False)
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def client(tmp_path_factory) -> EmbeddedComfyClient:
|
async def client(tmp_path_factory) -> EmbeddedComfyClient:
|
||||||
async with EmbeddedComfyClient() as client:
|
async with EmbeddedComfyClient() as client:
|
||||||
yield client
|
yield client
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user