diff --git a/comfy/cmd/worker.py b/comfy/cmd/worker.py index 454172dd9..bfd67f9c7 100644 --- a/comfy/cmd/worker.py +++ b/comfy/cmd/worker.py @@ -1,7 +1,7 @@ import asyncio +import os from .. import options -from ..distributed.distributed_prompt_worker import DistributedPromptWorker options.enable_args_parsing() @@ -14,6 +14,38 @@ async def main(): args.distributed_queue_frontend = False assert args.distributed_queue_connection_uri is not None, "Set the --distributed-queue-connection-uri argument to your RabbitMQ server" + + if args.cuda_device is not None: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) + print("Set cuda device to:", args.cuda_device) + + if args.deterministic: + if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" + + # configure paths + if args.output_directory: + output_dir = os.path.abspath(args.output_directory) + print(f"Setting output directory to: {output_dir}") + from ..cmd import folder_paths + + folder_paths.set_output_directory(output_dir) + + if args.input_directory: + input_dir = os.path.abspath(args.input_directory) + print(f"Setting input directory to: {input_dir}") + from ..cmd import folder_paths + + folder_paths.set_input_directory(input_dir) + + if args.temp_directory: + temp_dir = os.path.abspath(args.temp_directory) + print(f"Setting temp directory to: {temp_dir}") + from ..cmd import folder_paths + + folder_paths.set_temp_directory(temp_dir) + + from ..distributed.distributed_prompt_worker import DistributedPromptWorker async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri, queue_name=args.distributed_queue_name): stop = asyncio.Event() diff --git a/comfy/distributed/distributed_prompt_worker.py b/comfy/distributed/distributed_prompt_worker.py index f82d45f9a..ea6e38b43 100644 --- a/comfy/distributed/distributed_prompt_worker.py +++ b/comfy/distributed/distributed_prompt_worker.py @@ -1,6 +1,5 @@ import asyncio import logging -import sys from asyncio import AbstractEventLoop from contextlib import AsyncExitStack from dataclasses import asdict