mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-26 22:30:19 +08:00
Add locks to model_management to prevent multiple copies of the models from being loaded at the same time
This commit is contained in:
parent
8e9052c843
commit
d9b4607c36
@ -47,9 +47,10 @@ class EmbeddedComfyClient:
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, configuration: Optional[Configuration] = None, loop: Optional[AbstractEventLoop] = None):
|
def __init__(self, configuration: Optional[Configuration] = None, loop: Optional[AbstractEventLoop] = None,
|
||||||
|
max_workers: int = 1):
|
||||||
self._server_stub = ServerStub()
|
self._server_stub = ServerStub()
|
||||||
self._executor = ThreadPoolExecutor(max_workers=1)
|
self._executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||||
self._loop = loop or asyncio.get_event_loop()
|
self._loop = loop or asyncio.get_event_loop()
|
||||||
self._configuration = configuration
|
self._configuration = configuration
|
||||||
# we don't want to import the executor yet
|
# we don't want to import the executor yet
|
||||||
@ -71,6 +72,10 @@ class EmbeddedComfyClient:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# wait until the queue is done
|
||||||
|
while self._executor._work_queue.qsize() > 0:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
await self._loop.run_in_executor(self._executor, cleanup)
|
await self._loop.run_in_executor(self._executor, cleanup)
|
||||||
|
|
||||||
self._executor.shutdown(wait=True)
|
self._executor.shutdown(wait=True)
|
||||||
@ -105,6 +110,6 @@ class EmbeddedComfyClient:
|
|||||||
if self._prompt_executor.success:
|
if self._prompt_executor.success:
|
||||||
return self._prompt_executor.outputs_ui
|
return self._prompt_executor.outputs_ui
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("\n".join(self._prompt_executor.status_messages))
|
raise RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
|
||||||
|
|
||||||
return await self._loop.run_in_executor(self._executor, execute_prompt)
|
return await self._loop.run_in_executor(self._executor, execute_prompt)
|
||||||
|
|||||||
@ -2,10 +2,13 @@ import psutil
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from .cli_args import args
|
from .cli_args import args
|
||||||
from . import utils
|
from . import utils
|
||||||
|
from threading import RLock
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
model_management_lock = RLock()
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
DISABLED = 0 #No vram present: no need to move models to vram
|
DISABLED = 0 #No vram present: no need to move models to vram
|
||||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||||
@ -343,6 +346,7 @@ def minimum_inference_memory():
|
|||||||
return (1024 * 1024 * 1024)
|
return (1024 * 1024 * 1024)
|
||||||
|
|
||||||
def unload_model_clones(model):
|
def unload_model_clones(model):
|
||||||
|
with model_management_lock:
|
||||||
to_unload = []
|
to_unload = []
|
||||||
for i in range(len(current_loaded_models)):
|
for i in range(len(current_loaded_models)):
|
||||||
if model.is_clone(current_loaded_models[i].model):
|
if model.is_clone(current_loaded_models[i].model):
|
||||||
@ -353,6 +357,7 @@ def unload_model_clones(model):
|
|||||||
current_loaded_models.pop(i).model_unload()
|
current_loaded_models.pop(i).model_unload()
|
||||||
|
|
||||||
def free_memory(memory_required, device, keep_loaded=[]):
|
def free_memory(memory_required, device, keep_loaded=[]):
|
||||||
|
with model_management_lock:
|
||||||
unloaded_model = False
|
unloaded_model = False
|
||||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||||
if not DISABLE_SMART_MEMORY:
|
if not DISABLE_SMART_MEMORY:
|
||||||
@ -377,6 +382,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
|||||||
def load_models_gpu(models, memory_required=0):
|
def load_models_gpu(models, memory_required=0):
|
||||||
global vram_state
|
global vram_state
|
||||||
|
|
||||||
|
with model_management_lock:
|
||||||
inference_memory = minimum_inference_memory()
|
inference_memory = minimum_inference_memory()
|
||||||
extra_mem = max(inference_memory, memory_required)
|
extra_mem = max(inference_memory, memory_required)
|
||||||
|
|
||||||
@ -438,9 +444,11 @@ def load_models_gpu(models, memory_required=0):
|
|||||||
|
|
||||||
|
|
||||||
def load_model_gpu(model):
|
def load_model_gpu(model):
|
||||||
|
with model_management_lock:
|
||||||
return load_models_gpu([model])
|
return load_models_gpu([model])
|
||||||
|
|
||||||
def cleanup_models():
|
def cleanup_models():
|
||||||
|
with model_management_lock:
|
||||||
to_delete = []
|
to_delete = []
|
||||||
for i in range(len(current_loaded_models)):
|
for i in range(len(current_loaded_models)):
|
||||||
if sys.getrefcount(current_loaded_models[i].model) <= 2:
|
if sys.getrefcount(current_loaded_models[i].model) <= 2:
|
||||||
@ -593,6 +601,7 @@ def device_supports_non_blocking(device):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def cast_to_device(tensor, device, dtype, copy=False):
|
def cast_to_device(tensor, device, dtype, copy=False):
|
||||||
|
with model_management_lock:
|
||||||
device_supports_cast = False
|
device_supports_cast = False
|
||||||
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
||||||
device_supports_cast = True
|
device_supports_cast = True
|
||||||
@ -759,6 +768,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def soft_empty_cache(force=False):
|
def soft_empty_cache(force=False):
|
||||||
|
with model_management_lock:
|
||||||
global cpu_state
|
global cpu_state
|
||||||
if cpu_state == CPUState.MPS:
|
if cpu_state == CPUState.MPS:
|
||||||
torch.mps.empty_cache()
|
torch.mps.empty_cache()
|
||||||
@ -770,6 +780,7 @@ def soft_empty_cache(force=False):
|
|||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
|
|
||||||
def unload_all_models():
|
def unload_all_models():
|
||||||
|
with model_management_lock:
|
||||||
free_memory(1e30, get_torch_device())
|
free_memory(1e30, get_torch_device())
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -31,3 +33,10 @@ async def test_embedded_comfy():
|
|||||||
prompt = sdxl_workflow_with_refiner("test")
|
prompt = sdxl_workflow_with_refiner("test")
|
||||||
outputs = await client.queue_prompt(prompt)
|
outputs = await client.queue_prompt(prompt)
|
||||||
assert outputs["13"]["images"][0]["abs_path"] is not None
|
assert outputs["13"]["images"][0]["abs_path"] is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multithreaded_comfy():
|
||||||
|
async with EmbeddedComfyClient(max_workers=2) as client:
|
||||||
|
prompt = sdxl_workflow_with_refiner("test")
|
||||||
|
outputs_iter = await asyncio.gather(*[client.queue_prompt(prompt) for _ in range(4)])
|
||||||
|
assert all(outputs["13"]["images"][0]["abs_path"] is not None for outputs in outputs_iter)
|
||||||
Loading…
Reference in New Issue
Block a user