Tweaks to distributed queueing

- Do not auto delete the queue
 - Make the queue durable
 - Progress notifications expire
 - Deprecation fix
This commit is contained in:
doctorpangloss 2024-11-14 15:08:59 -08:00
parent 44be2591df
commit 4150dbbbe5
5 changed files with 13 additions and 10 deletions

View File

@ -61,7 +61,12 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
# todo: user_id should never be none here
return
await self._rpc.call(_get_name(self._queue_name, user_id), {"event": event, "data": data})
try:
# we don't need to await this coroutine
_ = asyncio.create_task(self._rpc.call(_get_name(self._queue_name, user_id), {"event": event, "data": data}, expiration=1000))
except asyncio.TimeoutError:
# these can gracefully expire
pass
def send_sync(self,
event: SendSyncEvent,

View File

@ -275,7 +275,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
return
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
self._channel = await self._connection.channel()
self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False)
self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=False, durable=True)
if self._is_caller:
self._caller_progress_handlers = ProgressHandlers(self._rpc, self._caller_server, self._queue_name)
# this makes the queue available to complete work items

View File

@ -10,9 +10,8 @@ KNOWN_CHAT_TEMPLATES = {}
def _update_known_chat_templates():
try:
_chat_templates: Traversable
with files(__package__) / "chat_templates" as _chat_templates:
_extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
_chat_templates: Traversable = files(__package__) / "chat_templates"
_extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
except ImportError as exc:
logging.warning("Could not load extra chat templates, some text models will fail", exc_info=exc)

View File

@ -197,7 +197,7 @@ async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_w
pytest.fail("Failed to get a 200 response with valid data within the timeout period")
class TestWorker(DistributedPromptWorker):
class Worker(DistributedPromptWorker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.processed_workflows: set[str] = set()
@ -215,9 +215,9 @@ async def test_two_workers_distinct_requests():
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
# Start two test workers
workers: list[TestWorker] = []
workers: list[Worker] = []
for i in range(2):
worker = TestWorker(connection_uri=connection_uri, health_check_port=9090 + i, executor=ProcessPoolExecutor(max_workers=1))
worker = Worker(connection_uri=connection_uri, health_check_port=9090 + i, executor=ProcessPoolExecutor(max_workers=1))
await worker.init()
workers.append(worker)

View File

@ -14,7 +14,6 @@ from . import workflows
@pytest.fixture(scope="module", autouse=False)
@pytest.mark.asyncio
async def client(tmp_path_factory) -> EmbeddedComfyClient:
async with EmbeddedComfyClient() as client:
yield client