mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-30 16:20:17 +08:00
Compare commits
3 Commits
b2c274c534
...
43c527bd74
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
43c527bd74 | ||
|
|
1a72bf2046 | ||
|
|
7cecb6dbf8 |
@ -108,7 +108,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||
- Works fully offline: core will never download anything unless you want to.
|
||||
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
|
||||
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview) disable with: `--disable-api-nodes`
|
||||
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
|
||||
|
||||
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
|
||||
torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
|
||||
|
||||
### Instructions:
|
||||
|
||||
@ -229,7 +229,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
||||
|
||||
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
|
||||
|
||||
|
||||
@ -152,6 +152,8 @@ parser.add_argument("--force-non-blocking", action="store_true", help="Force Com
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
|
||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||
parser.add_argument("--use-subprocess-workers", action="store_true", help="Execute each prompt in an isolated subprocess with complete GPU/ROCm context reset. Ensures clean state between jobs but adds startup overhead.")
|
||||
parser.add_argument("--subprocess-timeout", type=int, default=600, help="Timeout in seconds for subprocess execution (default: 600, only used with --use-subprocess-workers).")
|
||||
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||
|
||||
class PerformanceFeature(enum.Enum):
|
||||
|
||||
145
comfy/execution_core.py
Normal file
145
comfy/execution_core.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""Core execution logic shared between normal and subprocess execution modes."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
_active_worker = None
|
||||
|
||||
|
||||
def create_worker(server_instance):
|
||||
"""Create worker backend. Returns NativeWorker or SubprocessWorker."""
|
||||
global _active_worker
|
||||
from comfy.cli_args import args
|
||||
|
||||
server = WorkerServer(server_instance)
|
||||
|
||||
if args.use_subprocess_workers:
|
||||
from comfy.worker_process import SubprocessWorker
|
||||
worker = SubprocessWorker(server, timeout=args.subprocess_timeout)
|
||||
else:
|
||||
from comfy.worker_native import NativeWorker
|
||||
worker = NativeWorker(server)
|
||||
|
||||
_active_worker = worker
|
||||
return worker
|
||||
|
||||
|
||||
async def init_execution_environment():
|
||||
"""Load nodes and custom nodes. Returns number of node types loaded."""
|
||||
import nodes
|
||||
from comfy.cli_args import args
|
||||
|
||||
await nodes.init_extra_nodes(
|
||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||
init_api_nodes=not args.disable_api_nodes
|
||||
)
|
||||
return len(nodes.NODE_CLASS_MAPPINGS)
|
||||
|
||||
|
||||
def setup_progress_hook(server_instance, interrupt_checker):
|
||||
"""Set up global progress hook. interrupt_checker must raise on interrupt."""
|
||||
import comfy.utils
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
|
||||
def hook(value, total, preview_image, prompt_id=None, node_id=None):
|
||||
ctx = get_executing_context()
|
||||
if ctx:
|
||||
prompt_id = prompt_id or ctx.prompt_id
|
||||
node_id = node_id or ctx.node_id
|
||||
|
||||
interrupt_checker()
|
||||
|
||||
prompt_id = prompt_id or server_instance.last_prompt_id
|
||||
node_id = node_id or server_instance.last_node_id
|
||||
|
||||
get_progress_state().update_progress(node_id, value, total, preview_image)
|
||||
server_instance.send_sync("progress", {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}, server_instance.client_id)
|
||||
|
||||
comfy.utils.set_progress_bar_global_hook(hook)
|
||||
|
||||
|
||||
class WorkerServer:
|
||||
"""Protocol boundary: client_id, last_node_id, last_prompt_id, sockets_metadata, send_sync(), queue_updated()"""
|
||||
|
||||
_WRITABLE = {'client_id', 'last_node_id', 'last_prompt_id'}
|
||||
|
||||
def __init__(self, server):
|
||||
object.__setattr__(self, '_server', server)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name in self._WRITABLE:
|
||||
setattr(self._server, name, value)
|
||||
else:
|
||||
raise AttributeError(f"WorkerServer does not accept attribute '{name}'")
|
||||
|
||||
@property
|
||||
def client_id(self):
|
||||
return self._server.client_id
|
||||
|
||||
@property
|
||||
def last_node_id(self):
|
||||
return self._server.last_node_id
|
||||
|
||||
@property
|
||||
def last_prompt_id(self):
|
||||
return self._server.last_prompt_id
|
||||
|
||||
@property
|
||||
def sockets_metadata(self):
|
||||
return self._server.sockets_metadata
|
||||
|
||||
def send_sync(self, event, data, sid=None):
|
||||
self._server.send_sync(event, data, sid or self.client_id)
|
||||
|
||||
def queue_updated(self):
|
||||
self._server.queue_updated()
|
||||
|
||||
def interrupt_processing(value=True):
|
||||
_active_worker.interrupt(value)
|
||||
|
||||
|
||||
def _strip_sensitive(prompt):
|
||||
return prompt[:5] + prompt[6:]
|
||||
|
||||
|
||||
def prompt_worker(q, worker):
|
||||
"""Main prompt execution loop."""
|
||||
import execution
|
||||
|
||||
server = worker.server_instance
|
||||
|
||||
while True:
|
||||
queue_item = q.get(timeout=worker.get_gc_timeout())
|
||||
if queue_item is not None:
|
||||
item, item_id = queue_item
|
||||
start_time = time.perf_counter()
|
||||
prompt_id = item[1]
|
||||
server.last_prompt_id = prompt_id
|
||||
|
||||
extra_data = {**item[3], **item[5]}
|
||||
|
||||
result = worker.execute_prompt(item[2], prompt_id, extra_data, item[4], server=server)
|
||||
worker.mark_needs_gc()
|
||||
|
||||
q.task_done(
|
||||
item_id,
|
||||
result['history_result'],
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if result['success'] else 'error',
|
||||
completed=result['success'],
|
||||
messages=result['status_messages']
|
||||
),
|
||||
process_item=_strip_sensitive
|
||||
)
|
||||
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server.client_id)
|
||||
|
||||
elapsed = time.perf_counter() - start_time
|
||||
if elapsed > 600:
|
||||
logging.info(f"Prompt executed in {time.strftime('%H:%M:%S', time.gmtime(elapsed))}")
|
||||
else:
|
||||
logging.info(f"Prompt executed in {elapsed:.2f} seconds")
|
||||
|
||||
worker.handle_flags(q.get_flags())
|
||||
95
comfy/worker_native.py
Normal file
95
comfy/worker_native.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""Native (in-process) worker for prompt execution."""
|
||||
|
||||
import time
|
||||
import gc
|
||||
|
||||
|
||||
class NativeWorker:
|
||||
"""Executes prompts in the same process as the server."""
|
||||
|
||||
def __init__(self, server_instance, interrupt_checker=None):
|
||||
self.server_instance = server_instance
|
||||
self.interrupt_checker = interrupt_checker
|
||||
self.executor = None
|
||||
self.last_gc_collect = 0
|
||||
self.need_gc = False
|
||||
self.gc_collect_interval = 10.0
|
||||
|
||||
async def initialize(self):
|
||||
"""Load nodes and set up executor. Returns node count."""
|
||||
from execution import PromptExecutor, CacheType
|
||||
from comfy.cli_args import args
|
||||
from comfy.execution_core import init_execution_environment, setup_progress_hook
|
||||
import comfy.model_management as mm
|
||||
import hook_breaker_ac10a0
|
||||
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
try:
|
||||
node_count = await init_execution_environment()
|
||||
finally:
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
interrupt_checker = self.interrupt_checker or mm.throw_exception_if_processing_interrupted
|
||||
setup_progress_hook(self.server_instance, interrupt_checker=interrupt_checker)
|
||||
|
||||
cache_type = CacheType.CLASSIC
|
||||
if args.cache_lru > 0:
|
||||
cache_type = CacheType.LRU
|
||||
elif args.cache_ram > 0:
|
||||
cache_type = CacheType.RAM_PRESSURE
|
||||
elif args.cache_none:
|
||||
cache_type = CacheType.NONE
|
||||
|
||||
self.executor = PromptExecutor(
|
||||
self.server_instance,
|
||||
cache_type=cache_type,
|
||||
cache_args={"lru": args.cache_lru, "ram": args.cache_ram}
|
||||
)
|
||||
return node_count
|
||||
|
||||
def execute_prompt(self, prompt, prompt_id, extra_data, execute_outputs, server=None):
|
||||
self.executor.execute(prompt, prompt_id, extra_data, execute_outputs)
|
||||
return {
|
||||
'success': self.executor.success,
|
||||
'history_result': self.executor.history_result,
|
||||
'status_messages': self.executor.status_messages,
|
||||
'prompt_id': prompt_id
|
||||
}
|
||||
|
||||
def handle_flags(self, flags):
|
||||
import comfy.model_management as mm
|
||||
import hook_breaker_ac10a0
|
||||
|
||||
free_memory = flags.get("free_memory", False)
|
||||
|
||||
if flags.get("unload_models", free_memory):
|
||||
mm.unload_all_models()
|
||||
self.need_gc = True
|
||||
self.last_gc_collect = 0
|
||||
|
||||
if free_memory:
|
||||
if self.executor:
|
||||
self.executor.reset()
|
||||
self.need_gc = True
|
||||
self.last_gc_collect = 0
|
||||
|
||||
if self.need_gc:
|
||||
current_time = time.perf_counter()
|
||||
if (current_time - self.last_gc_collect) > self.gc_collect_interval:
|
||||
gc.collect()
|
||||
mm.soft_empty_cache()
|
||||
self.last_gc_collect = current_time
|
||||
self.need_gc = False
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
def interrupt(self, value=True):
|
||||
import comfy.model_management
|
||||
comfy.model_management.interrupt_current_processing(value)
|
||||
|
||||
def mark_needs_gc(self):
|
||||
self.need_gc = True
|
||||
|
||||
def get_gc_timeout(self):
|
||||
if self.need_gc:
|
||||
return max(self.gc_collect_interval - (time.perf_counter() - self.last_gc_collect), 0.0)
|
||||
return 1000.0
|
||||
179
comfy/worker_process.py
Normal file
179
comfy/worker_process.py
Normal file
@ -0,0 +1,179 @@
|
||||
"""Subprocess worker for isolated prompt execution with complete GPU/ROCm reset."""
|
||||
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
import traceback
|
||||
|
||||
mp.set_start_method('spawn', force=True)
|
||||
|
||||
|
||||
def _deserialize_preview(msg):
|
||||
"""Deserialize preview image from IPC transport."""
|
||||
if not (isinstance(msg['data'], dict) and msg['data'].get('_serialized')):
|
||||
return msg
|
||||
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import base64
|
||||
|
||||
s = msg['data']
|
||||
pil_image = Image.open(BytesIO(base64.b64decode(s['image_bytes'])))
|
||||
msg['data'] = ((s['image_type'], pil_image, s['max_size']), s['metadata'])
|
||||
return msg
|
||||
|
||||
|
||||
def _error_result(worker_id, prompt_id, error, tb=None):
|
||||
return {
|
||||
'success': False,
|
||||
'error': error,
|
||||
'traceback': tb,
|
||||
'history_result': {},
|
||||
'status_messages': [],
|
||||
'worker_id': worker_id,
|
||||
'prompt_id': prompt_id
|
||||
}
|
||||
|
||||
|
||||
def _kill_worker(worker, worker_id):
|
||||
if not worker.is_alive():
|
||||
return
|
||||
worker.terminate()
|
||||
worker.join(timeout=2)
|
||||
if worker.is_alive():
|
||||
logging.warning(f"Worker {worker_id} didn't terminate, killing")
|
||||
worker.kill()
|
||||
worker.join()
|
||||
|
||||
|
||||
class SubprocessWorker:
|
||||
"""Executes each prompt in an isolated subprocess with fresh GPU context."""
|
||||
|
||||
def __init__(self, server_instance, timeout=600):
|
||||
self.server_instance = server_instance
|
||||
self.timeout = timeout
|
||||
self.worker_counter = 0
|
||||
self.current_worker = None
|
||||
self.interrupt_event = None
|
||||
logging.info("SubprocessWorker created - each job will run in isolated process")
|
||||
|
||||
async def initialize(self):
|
||||
"""Load node definitions for prompt validation. Returns node count."""
|
||||
from comfy.execution_core import init_execution_environment
|
||||
return await init_execution_environment()
|
||||
|
||||
def handle_flags(self, flags):
|
||||
pass
|
||||
|
||||
def mark_needs_gc(self):
|
||||
pass
|
||||
|
||||
def get_gc_timeout(self):
|
||||
return 1000.0
|
||||
|
||||
def interrupt(self, value=True):
|
||||
if not value:
|
||||
return
|
||||
if self.interrupt_event:
|
||||
self.interrupt_event.set()
|
||||
if self.current_worker and self.current_worker.is_alive():
|
||||
self.current_worker.join(timeout=2)
|
||||
_kill_worker(self.current_worker, self.worker_counter)
|
||||
self.current_worker = None
|
||||
|
||||
def _relay_messages(self, message_queue, server):
|
||||
"""Relay queued messages to UI."""
|
||||
while not message_queue.empty():
|
||||
try:
|
||||
msg = _deserialize_preview(message_queue.get_nowait())
|
||||
if server:
|
||||
server.send_sync(msg['event'], msg['data'], msg['sid'])
|
||||
except:
|
||||
break
|
||||
|
||||
def execute_prompt(self, prompt, prompt_id, extra_data={}, execute_outputs=[], server=None):
|
||||
self.worker_counter += 1
|
||||
worker_id = self.worker_counter
|
||||
|
||||
job_queue = mp.Queue()
|
||||
result_queue = mp.Queue()
|
||||
message_queue = mp.Queue()
|
||||
self.interrupt_event = mp.Event()
|
||||
|
||||
client_id = extra_data.get('client_id')
|
||||
client_metadata = {}
|
||||
if client_id and hasattr(server, 'sockets_metadata'):
|
||||
client_metadata = server.sockets_metadata.get(client_id, {})
|
||||
|
||||
job_data = {
|
||||
'prompt': prompt,
|
||||
'prompt_id': prompt_id,
|
||||
'extra_data': extra_data,
|
||||
'execute_outputs': execute_outputs,
|
||||
'client_sockets_metadata': client_metadata
|
||||
}
|
||||
|
||||
from comfy.worker_process_child import worker_main
|
||||
worker = mp.Process(
|
||||
target=worker_main,
|
||||
args=(job_queue, result_queue, message_queue, self.interrupt_event, worker_id),
|
||||
name=f'ComfyUI-Worker-{worker_id}'
|
||||
)
|
||||
|
||||
logging.info(f"Starting worker {worker_id} for prompt {prompt_id}")
|
||||
self.current_worker = worker
|
||||
worker.start()
|
||||
job_queue.put(job_data)
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
result = None
|
||||
|
||||
while result is None:
|
||||
if self.interrupt_event.is_set():
|
||||
logging.info(f"Worker {worker_id} interrupted")
|
||||
if server:
|
||||
server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server.client_id)
|
||||
return _error_result(worker_id, prompt_id, 'Execution interrupted by user')
|
||||
|
||||
if time.time() - start_time > self.timeout:
|
||||
raise TimeoutError()
|
||||
|
||||
self._relay_messages(message_queue, server)
|
||||
|
||||
try:
|
||||
result = result_queue.get(timeout=0.1)
|
||||
except mp.queues.Empty:
|
||||
pass
|
||||
|
||||
self._relay_messages(message_queue, server)
|
||||
|
||||
worker.join(timeout=5)
|
||||
if worker.is_alive():
|
||||
_kill_worker(worker, worker_id)
|
||||
|
||||
logging.info(f"Worker {worker_id} cleaned up (exit code: {worker.exitcode})")
|
||||
self.current_worker = None
|
||||
return result
|
||||
|
||||
except TimeoutError:
|
||||
error = f"Worker {worker_id} timed out after {self.timeout}s. Try --subprocess-timeout to increase."
|
||||
logging.error(error)
|
||||
_kill_worker(worker, worker_id)
|
||||
self.current_worker = None
|
||||
return _error_result(worker_id, prompt_id, error)
|
||||
|
||||
except Exception as e:
|
||||
error = f"Worker {worker_id} IPC error: {e}"
|
||||
logging.error(f"{error}\n{traceback.format_exc()}")
|
||||
_kill_worker(worker, worker_id)
|
||||
self.current_worker = None
|
||||
return _error_result(worker_id, prompt_id, error, traceback.format_exc())
|
||||
|
||||
finally:
|
||||
for q in (job_queue, result_queue, message_queue):
|
||||
q.close()
|
||||
try:
|
||||
q.join_thread()
|
||||
except:
|
||||
pass
|
||||
104
comfy/worker_process_child.py
Normal file
104
comfy/worker_process_child.py
Normal file
@ -0,0 +1,104 @@
|
||||
"""Subprocess worker child process entry point."""
|
||||
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import traceback
|
||||
|
||||
|
||||
class IPCMessageServer:
|
||||
"""IPC-based message server for subprocess workers."""
|
||||
|
||||
def __init__(self, message_queue, client_id=None, sockets_metadata=None):
|
||||
self.message_queue = message_queue
|
||||
self.client_id = client_id
|
||||
self.last_node_id = None
|
||||
self.last_prompt_id = None
|
||||
self.sockets_metadata = sockets_metadata or {}
|
||||
|
||||
def send_sync(self, event, data, sid=None):
|
||||
from protocol import BinaryEventTypes
|
||||
from io import BytesIO
|
||||
import base64
|
||||
|
||||
if event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA and isinstance(data, tuple):
|
||||
preview_image, metadata = data
|
||||
image_type, pil_image, max_size = preview_image
|
||||
|
||||
buffer = BytesIO()
|
||||
pil_image.save(buffer, format=image_type)
|
||||
|
||||
data = {
|
||||
'_serialized': True,
|
||||
'image_type': image_type,
|
||||
'image_bytes': base64.b64encode(buffer.getvalue()).decode('utf-8'),
|
||||
'max_size': max_size,
|
||||
'metadata': metadata
|
||||
}
|
||||
|
||||
self.message_queue.put_nowait({'event': event, 'data': data, 'sid': sid})
|
||||
|
||||
def queue_updated(self):
|
||||
pass
|
||||
|
||||
|
||||
def worker_main(job_queue, result_queue, message_queue, interrupt_event, worker_id):
|
||||
"""Subprocess worker entry point - spawned fresh for each execution."""
|
||||
job_data = None
|
||||
try:
|
||||
logging.basicConfig(level=logging.INFO, format=f'[Worker-{worker_id}] %(levelname)s: %(message)s')
|
||||
logging.info(f"Worker {worker_id} starting (PID: {mp.current_process().pid})")
|
||||
|
||||
import asyncio
|
||||
import comfy.model_management
|
||||
from comfy.worker_native import NativeWorker
|
||||
from comfy.execution_core import WorkerServer
|
||||
|
||||
logging.info(f"Worker {worker_id} initialized. Device: {comfy.model_management.get_torch_device()}")
|
||||
|
||||
job_data = job_queue.get(timeout=30)
|
||||
client_id = job_data.get('extra_data', {}).get('client_id')
|
||||
client_metadata = job_data.get('client_sockets_metadata', {})
|
||||
|
||||
sockets_metadata = {client_id: client_metadata} if client_id and client_metadata else {}
|
||||
ipc_server = IPCMessageServer(message_queue, client_id, sockets_metadata)
|
||||
server = WorkerServer(ipc_server)
|
||||
|
||||
def check_interrupt():
|
||||
if interrupt_event.is_set():
|
||||
raise comfy.model_management.InterruptProcessingException()
|
||||
|
||||
worker = NativeWorker(server, interrupt_checker=check_interrupt)
|
||||
|
||||
import comfy.execution_core
|
||||
comfy.execution_core._active_worker = worker
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
node_count = loop.run_until_complete(worker.initialize())
|
||||
logging.info(f"Worker {worker_id} loaded {node_count} node types")
|
||||
|
||||
result = worker.execute_prompt(
|
||||
job_data['prompt'],
|
||||
job_data['prompt_id'],
|
||||
job_data.get('extra_data', {}),
|
||||
job_data.get('execute_outputs', [])
|
||||
)
|
||||
result['worker_id'] = worker_id
|
||||
|
||||
logging.info(f"Worker {worker_id} completed successfully")
|
||||
result_queue.put(result)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Worker {worker_id} failed: {e}\n{traceback.format_exc()}")
|
||||
result_queue.put({
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
'traceback': traceback.format_exc(),
|
||||
'history_result': {},
|
||||
'status_messages': [],
|
||||
'worker_id': worker_id,
|
||||
'prompt_id': job_data.get('prompt_id', 'unknown') if job_data else 'unknown'
|
||||
})
|
||||
|
||||
finally:
|
||||
logging.info(f"Worker {worker_id} exiting")
|
||||
142
main.py
142
main.py
@ -12,9 +12,6 @@ import itertools
|
||||
import utils.extra_config
|
||||
import logging
|
||||
import sys
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_api import feature_flags
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -175,16 +172,22 @@ if 'torch' in sys.modules:
|
||||
|
||||
import comfy.utils
|
||||
|
||||
import execution
|
||||
import server
|
||||
from protocol import BinaryEventTypes
|
||||
import nodes
|
||||
import comfy.model_management
|
||||
import comfyui_version
|
||||
import app.logger
|
||||
import hook_breaker_ac10a0
|
||||
|
||||
# Import modules needed for server operation
|
||||
# GPU initialization happens lazily when GPU functions are called
|
||||
# In subprocess mode, main process won't call GPU functions - workers will
|
||||
if __name__ == "__main__":
|
||||
import execution
|
||||
import nodes
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
def cuda_malloc_warning():
|
||||
if args.use_subprocess_workers:
|
||||
return
|
||||
device = comfy.model_management.get_torch_device()
|
||||
device_name = comfy.model_management.get_torch_device_name(device)
|
||||
cuda_malloc_warning = False
|
||||
@ -196,84 +199,6 @@ def cuda_malloc_warning():
|
||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||
|
||||
|
||||
def prompt_worker(q, server_instance):
|
||||
current_time: float = 0.0
|
||||
cache_type = execution.CacheType.CLASSIC
|
||||
if args.cache_lru > 0:
|
||||
cache_type = execution.CacheType.LRU
|
||||
elif args.cache_ram > 0:
|
||||
cache_type = execution.CacheType.RAM_PRESSURE
|
||||
elif args.cache_none:
|
||||
cache_type = execution.CacheType.NONE
|
||||
|
||||
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
|
||||
last_gc_collect = 0
|
||||
need_gc = False
|
||||
gc_collect_interval = 10.0
|
||||
|
||||
while True:
|
||||
timeout = 1000.0
|
||||
if need_gc:
|
||||
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
|
||||
|
||||
queue_item = q.get(timeout=timeout)
|
||||
if queue_item is not None:
|
||||
item, item_id = queue_item
|
||||
execution_start_time = time.perf_counter()
|
||||
prompt_id = item[1]
|
||||
server_instance.last_prompt_id = prompt_id
|
||||
|
||||
sensitive = item[5]
|
||||
extra_data = item[3].copy()
|
||||
for k in sensitive:
|
||||
extra_data[k] = sensitive[k]
|
||||
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
need_gc = True
|
||||
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
|
||||
# Log Time in a more readable way after 10 minutes
|
||||
if execution_time > 600:
|
||||
execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time))
|
||||
logging.info(f"Prompt executed in {execution_time}")
|
||||
else:
|
||||
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||
|
||||
flags = q.get_flags()
|
||||
free_memory = flags.get("free_memory", False)
|
||||
|
||||
if flags.get("unload_models", free_memory):
|
||||
comfy.model_management.unload_all_models()
|
||||
need_gc = True
|
||||
last_gc_collect = 0
|
||||
|
||||
if free_memory:
|
||||
e.reset()
|
||||
need_gc = True
|
||||
last_gc_collect = 0
|
||||
|
||||
if need_gc:
|
||||
current_time = time.perf_counter()
|
||||
if (current_time - last_gc_collect) > gc_collect_interval:
|
||||
gc.collect()
|
||||
comfy.model_management.soft_empty_cache()
|
||||
last_gc_collect = current_time
|
||||
need_gc = False
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
|
||||
async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
|
||||
addresses = []
|
||||
for addr in address.split(","):
|
||||
@ -282,37 +207,6 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
|
||||
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
|
||||
)
|
||||
|
||||
def hijack_progress(server_instance):
|
||||
def hook(value, total, preview_image, prompt_id=None, node_id=None):
|
||||
executing_context = get_executing_context()
|
||||
if prompt_id is None and executing_context is not None:
|
||||
prompt_id = executing_context.prompt_id
|
||||
if node_id is None and executing_context is not None:
|
||||
node_id = executing_context.node_id
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
if prompt_id is None:
|
||||
prompt_id = server_instance.last_prompt_id
|
||||
if node_id is None:
|
||||
node_id = server_instance.last_node_id
|
||||
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
|
||||
get_progress_state().update_progress(node_id, value, total, preview_image)
|
||||
|
||||
server_instance.send_sync("progress", progress, server_instance.client_id)
|
||||
if preview_image is not None:
|
||||
# Only send old method if client doesn't support preview metadata
|
||||
if not feature_flags.supports_feature(
|
||||
server_instance.sockets_metadata,
|
||||
server_instance.client_id,
|
||||
"supports_preview_metadata",
|
||||
):
|
||||
server_instance.send_sync(
|
||||
BinaryEventTypes.UNENCODED_PREVIEW_IMAGE,
|
||||
preview_image,
|
||||
server_instance.client_id,
|
||||
)
|
||||
|
||||
comfy.utils.set_progress_bar_global_hook(hook)
|
||||
|
||||
|
||||
def cleanup_temp():
|
||||
temp_dir = folder_paths.get_temp_directory()
|
||||
@ -357,20 +251,16 @@ def start_comfyui(asyncio_loop=None):
|
||||
if args.enable_manager and not args.disable_manager_ui:
|
||||
comfyui_manager.start()
|
||||
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
asyncio_loop.run_until_complete(nodes.init_extra_nodes(
|
||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||
init_api_nodes=not args.disable_api_nodes
|
||||
))
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
from comfy.execution_core import create_worker, prompt_worker
|
||||
worker = create_worker(prompt_server)
|
||||
node_count = asyncio_loop.run_until_complete(worker.initialize())
|
||||
logging.info(f"Loaded {node_count} node types")
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, worker), name="PromptWorker").start()
|
||||
|
||||
cuda_malloc_warning()
|
||||
setup_database()
|
||||
|
||||
prompt_server.add_routes()
|
||||
hijack_progress(prompt_server)
|
||||
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
|
||||
|
||||
if args.quick_test_for_ci:
|
||||
exit(0)
|
||||
|
||||
3
nodes.py
3
nodes.py
@ -51,7 +51,8 @@ def before_node_execution():
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
def interrupt_processing(value=True):
|
||||
comfy.model_management.interrupt_current_processing(value)
|
||||
from comfy.execution_core import interrupt_processing as core_interrupt
|
||||
core_interrupt(value)
|
||||
|
||||
MAX_RESOLUTION=16384
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user