Add locks to model_management to prevent multiple copies of the models from being loaded at the same time

This commit is contained in:
doctorpangloss 2024-02-07 15:18:13 -08:00
parent 8e9052c843
commit d9b4607c36
3 changed files with 142 additions and 117 deletions

View File

@ -47,9 +47,10 @@ class EmbeddedComfyClient:
``` ```
""" """
def __init__(self, configuration: Optional[Configuration] = None, loop: Optional[AbstractEventLoop] = None): def __init__(self, configuration: Optional[Configuration] = None, loop: Optional[AbstractEventLoop] = None,
max_workers: int = 1):
self._server_stub = ServerStub() self._server_stub = ServerStub()
self._executor = ThreadPoolExecutor(max_workers=1) self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._loop = loop or asyncio.get_event_loop() self._loop = loop or asyncio.get_event_loop()
self._configuration = configuration self._configuration = configuration
# we don't want to import the executor yet # we don't want to import the executor yet
@ -71,6 +72,10 @@ class EmbeddedComfyClient:
except: except:
pass pass
# wait until the queue is done
while self._executor._work_queue.qsize() > 0:
await asyncio.sleep(0.1)
await self._loop.run_in_executor(self._executor, cleanup) await self._loop.run_in_executor(self._executor, cleanup)
self._executor.shutdown(wait=True) self._executor.shutdown(wait=True)
@ -105,6 +110,6 @@ class EmbeddedComfyClient:
if self._prompt_executor.success: if self._prompt_executor.success:
return self._prompt_executor.outputs_ui return self._prompt_executor.outputs_ui
else: else:
raise RuntimeError("\n".join(self._prompt_executor.status_messages)) raise RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
return await self._loop.run_in_executor(self._executor, execute_prompt) return await self._loop.run_in_executor(self._executor, execute_prompt)

View File

@ -2,10 +2,13 @@ import psutil
from enum import Enum from enum import Enum
from .cli_args import args from .cli_args import args
from . import utils from . import utils
from threading import RLock
import torch import torch
import sys import sys
model_management_lock = RLock()
class VRAMState(Enum): class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram NO_VRAM = 1 #Very low vram: enable all the options to save vram
@ -343,6 +346,7 @@ def minimum_inference_memory():
return (1024 * 1024 * 1024) return (1024 * 1024 * 1024)
def unload_model_clones(model): def unload_model_clones(model):
with model_management_lock:
to_unload = [] to_unload = []
for i in range(len(current_loaded_models)): for i in range(len(current_loaded_models)):
if model.is_clone(current_loaded_models[i].model): if model.is_clone(current_loaded_models[i].model):
@ -353,6 +357,7 @@ def unload_model_clones(model):
current_loaded_models.pop(i).model_unload() current_loaded_models.pop(i).model_unload()
def free_memory(memory_required, device, keep_loaded=[]): def free_memory(memory_required, device, keep_loaded=[]):
with model_management_lock:
unloaded_model = False unloaded_model = False
for i in range(len(current_loaded_models) -1, -1, -1): for i in range(len(current_loaded_models) -1, -1, -1):
if not DISABLE_SMART_MEMORY: if not DISABLE_SMART_MEMORY:
@ -377,6 +382,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
def load_models_gpu(models, memory_required=0): def load_models_gpu(models, memory_required=0):
global vram_state global vram_state
with model_management_lock:
inference_memory = minimum_inference_memory() inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required) extra_mem = max(inference_memory, memory_required)
@ -438,9 +444,11 @@ def load_models_gpu(models, memory_required=0):
def load_model_gpu(model):
    """Load a single model by delegating to ``load_models_gpu``.

    The model is wrapped in a one-element list and handed to
    ``load_models_gpu`` while ``model_management_lock`` is held; whatever
    that call returns is returned unchanged.
    """
    single_model_batch = [model]
    with model_management_lock:
        return load_models_gpu(single_model_batch)
def cleanup_models(): def cleanup_models():
with model_management_lock:
to_delete = [] to_delete = []
for i in range(len(current_loaded_models)): for i in range(len(current_loaded_models)):
if sys.getrefcount(current_loaded_models[i].model) <= 2: if sys.getrefcount(current_loaded_models[i].model) <= 2:
@ -593,6 +601,7 @@ def device_supports_non_blocking(device):
return True return True
def cast_to_device(tensor, device, dtype, copy=False): def cast_to_device(tensor, device, dtype, copy=False):
with model_management_lock:
device_supports_cast = False device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
device_supports_cast = True device_supports_cast = True
@ -759,6 +768,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return True return True
def soft_empty_cache(force=False): def soft_empty_cache(force=False):
with model_management_lock:
global cpu_state global cpu_state
if cpu_state == CPUState.MPS: if cpu_state == CPUState.MPS:
torch.mps.empty_cache() torch.mps.empty_cache()
@ -770,6 +780,7 @@ def soft_empty_cache(force=False):
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
def unload_all_models():
    """Request that ``free_memory`` release 1e30 bytes on the torch device.

    The request size is far larger than any real device, so the call asks
    ``free_memory`` to give back as much as it is able to. The whole
    operation runs while ``model_management_lock`` is held.
    """
    with model_management_lock:
        target_device = get_torch_device()
        free_memory(1e30, target_device)

View File

@ -1,3 +1,5 @@
import asyncio
import pytest import pytest
import torch import torch
@ -31,3 +33,10 @@ async def test_embedded_comfy():
prompt = sdxl_workflow_with_refiner("test") prompt = sdxl_workflow_with_refiner("test")
outputs = await client.queue_prompt(prompt) outputs = await client.queue_prompt(prompt)
assert outputs["13"]["images"][0]["abs_path"] is not None assert outputs["13"]["images"][0]["abs_path"] is not None
@pytest.mark.asyncio
async def test_multithreaded_comfy():
    """Queue the same workflow four times on a client built with two workers.

    All four ``queue_prompt`` calls are awaited together via
    ``asyncio.gather``; every result must expose a non-None ``abs_path``
    for the first image of node "13".
    """
    async with EmbeddedComfyClient(max_workers=2) as client:
        prompt = sdxl_workflow_with_refiner("test")
        pending_runs = [client.queue_prompt(prompt) for _ in range(4)]
        outputs_iter = await asyncio.gather(*pending_runs)
        for outputs in outputs_iter:
            assert outputs["13"]["images"][0]["abs_path"] is not None