Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-22 12:20:16 +08:00
Add locks to model_management to prevent multiple copies of the models from being loaded at the same time

This commit is contained in:
parent 8e9052c843
commit d9b4607c36
@@ -47,9 +47,10 @@ class EmbeddedComfyClient:
     ```
     """

-    def __init__(self, configuration: Optional[Configuration] = None, loop: Optional[AbstractEventLoop] = None):
+    def __init__(self, configuration: Optional[Configuration] = None, loop: Optional[AbstractEventLoop] = None,
+                 max_workers: int = 1):
         self._server_stub = ServerStub()
-        self._executor = ThreadPoolExecutor(max_workers=1)
+        self._executor = ThreadPoolExecutor(max_workers=max_workers)
         self._loop = loop or asyncio.get_event_loop()
         self._configuration = configuration
         # we don't want to import the executor yet
@@ -71,6 +72,10 @@ class EmbeddedComfyClient:
             except:
                 pass

+        # wait until the queue is done
+        while self._executor._work_queue.qsize() > 0:
+            await asyncio.sleep(0.1)
+
         await self._loop.run_in_executor(self._executor, cleanup)

         self._executor.shutdown(wait=True)
@@ -105,6 +110,6 @@ class EmbeddedComfyClient:
             if self._prompt_executor.success:
                 return self._prompt_executor.outputs_ui
             else:
-                raise RuntimeError("\n".join(self._prompt_executor.status_messages))
+                raise RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))

         return await self._loop.run_in_executor(self._executor, execute_prompt)
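With max_workers exposed, several prompts can be queued against one embedded client at the same time; the sketch below mirrors the test added at the bottom of this diff. It is illustrative only: the import paths for EmbeddedComfyClient and sdxl_workflow_with_refiner are not shown in these hunks and are assumed from the test file.

import asyncio

async def main():
    # Two executor threads; prompts submitted with queue_prompt() are dispatched
    # onto them via run_in_executor.
    async with EmbeddedComfyClient(max_workers=2) as client:
        prompt = sdxl_workflow_with_refiner("test")
        # Four prompts in flight at once; the model_management lock added below
        # keeps the worker threads from loading duplicate copies of the same weights.
        results = await asyncio.gather(*[client.queue_prompt(prompt) for _ in range(4)])
        print(len(results), "prompts completed")

asyncio.run(main())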
@@ -2,10 +2,13 @@ import psutil
 from enum import Enum
 from .cli_args import args
 from . import utils
+from threading import RLock

 import torch
 import sys

+model_management_lock = RLock()
+
 class VRAMState(Enum):
     DISABLED = 0 #No vram present: no need to move models to vram
     NO_VRAM = 1 #Very low vram: enable all the options to save vram
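A reentrant lock (RLock) rather than a plain Lock matters here: in the hunks below, load_models_gpu() holds the lock while it calls unload_model_clones() and free_memory(), which each acquire the same lock again from the same thread. A minimal, self-contained sketch of that pattern (the function bodies are placeholders, not the real ones):

from threading import RLock

model_management_lock = RLock()

def free_memory_sketch():
    with model_management_lock:      # nested acquisition by the same thread
        print("freeing memory while the caller still holds the lock")

def load_models_gpu_sketch():
    with model_management_lock:      # outer acquisition
        free_memory_sketch()         # re-enters without deadlocking; a plain Lock would block forever here

load_models_gpu_sketch()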
@@ -343,113 +346,118 @@ def minimum_inference_memory():
     return (1024 * 1024 * 1024)

 def unload_model_clones(model):
-    to_unload = []
-    for i in range(len(current_loaded_models)):
-        if model.is_clone(current_loaded_models[i].model):
-            to_unload = [i] + to_unload
+    with model_management_lock:
+        to_unload = []
+        for i in range(len(current_loaded_models)):
+            if model.is_clone(current_loaded_models[i].model):
+                to_unload = [i] + to_unload

-    for i in to_unload:
-        print("unload clone", i)
-        current_loaded_models.pop(i).model_unload()
+        for i in to_unload:
+            print("unload clone", i)
+            current_loaded_models.pop(i).model_unload()

 def free_memory(memory_required, device, keep_loaded=[]):
-    unloaded_model = False
-    for i in range(len(current_loaded_models) -1, -1, -1):
-        if not DISABLE_SMART_MEMORY:
-            if get_free_memory(device) > memory_required:
-                break
-        shift_model = current_loaded_models[i]
-        if shift_model.device == device:
-            if shift_model not in keep_loaded:
-                m = current_loaded_models.pop(i)
-                m.model_unload()
-                del m
-                unloaded_model = True
+    with model_management_lock:
+        unloaded_model = False
+        for i in range(len(current_loaded_models) -1, -1, -1):
+            if not DISABLE_SMART_MEMORY:
+                if get_free_memory(device) > memory_required:
+                    break
+            shift_model = current_loaded_models[i]
+            if shift_model.device == device:
+                if shift_model not in keep_loaded:
+                    m = current_loaded_models.pop(i)
+                    m.model_unload()
+                    del m
+                    unloaded_model = True

-    if unloaded_model:
-        soft_empty_cache()
-    else:
-        if vram_state != VRAMState.HIGH_VRAM:
-            mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
-            if mem_free_torch > mem_free_total * 0.25:
-                soft_empty_cache()
+        if unloaded_model:
+            soft_empty_cache()
+        else:
+            if vram_state != VRAMState.HIGH_VRAM:
+                mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
+                if mem_free_torch > mem_free_total * 0.25:
+                    soft_empty_cache()

 def load_models_gpu(models, memory_required=0):
     global vram_state

-    inference_memory = minimum_inference_memory()
-    extra_mem = max(inference_memory, memory_required)
+    with model_management_lock:
+        inference_memory = minimum_inference_memory()
+        extra_mem = max(inference_memory, memory_required)

-    models_to_load = []
-    models_already_loaded = []
-    for x in models:
-        loaded_model = LoadedModel(x)
+        models_to_load = []
+        models_already_loaded = []
+        for x in models:
+            loaded_model = LoadedModel(x)

-        if loaded_model in current_loaded_models:
-            index = current_loaded_models.index(loaded_model)
-            current_loaded_models.insert(0, current_loaded_models.pop(index))
-            models_already_loaded.append(loaded_model)
-        else:
-            if hasattr(x, "model"):
-                print(f"Requested to load {x.model.__class__.__name__}")
-            models_to_load.append(loaded_model)
+            if loaded_model in current_loaded_models:
+                index = current_loaded_models.index(loaded_model)
+                current_loaded_models.insert(0, current_loaded_models.pop(index))
+                models_already_loaded.append(loaded_model)
+            else:
+                if hasattr(x, "model"):
+                    print(f"Requested to load {x.model.__class__.__name__}")
+                models_to_load.append(loaded_model)

-    if len(models_to_load) == 0:
-        devs = set(map(lambda a: a.device, models_already_loaded))
-        for d in devs:
-            if d != torch.device("cpu"):
-                free_memory(extra_mem, d, models_already_loaded)
-        return
+        if len(models_to_load) == 0:
+            devs = set(map(lambda a: a.device, models_already_loaded))
+            for d in devs:
+                if d != torch.device("cpu"):
+                    free_memory(extra_mem, d, models_already_loaded)
+            return

-    print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
+        print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")

-    total_memory_required = {}
-    for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model)
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        total_memory_required = {}
+        for loaded_model in models_to_load:
+            unload_model_clones(loaded_model.model)
+            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

-    for device in total_memory_required:
-        if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
+        for device in total_memory_required:
+            if device != torch.device("cpu"):
+                free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)

-    for loaded_model in models_to_load:
-        model = loaded_model.model
-        torch_dev = model.load_device
-        if is_device_cpu(torch_dev):
-            vram_set_state = VRAMState.DISABLED
-        else:
-            vram_set_state = vram_state
-        lowvram_model_memory = 0
-        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
-            model_size = loaded_model.model_memory_required(torch_dev)
-            current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
-            if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
-                vram_set_state = VRAMState.LOW_VRAM
-            else:
-                lowvram_model_memory = 0
+        for loaded_model in models_to_load:
+            model = loaded_model.model
+            torch_dev = model.load_device
+            if is_device_cpu(torch_dev):
+                vram_set_state = VRAMState.DISABLED
+            else:
+                vram_set_state = vram_state
+            lowvram_model_memory = 0
+            if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
+                model_size = loaded_model.model_memory_required(torch_dev)
+                current_free_mem = get_free_memory(torch_dev)
+                lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
+                if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
+                    vram_set_state = VRAMState.LOW_VRAM
+                else:
+                    lowvram_model_memory = 0

-        if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 64 * 1024 * 1024
+            if vram_set_state == VRAMState.NO_VRAM:
+                lowvram_model_memory = 64 * 1024 * 1024

-        cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
-        current_loaded_models.insert(0, loaded_model)
-    return
+            cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
+            current_loaded_models.insert(0, loaded_model)
+        return


 def load_model_gpu(model):
-    return load_models_gpu([model])
+    with model_management_lock:
+        return load_models_gpu([model])

 def cleanup_models():
-    to_delete = []
-    for i in range(len(current_loaded_models)):
-        if sys.getrefcount(current_loaded_models[i].model) <= 2:
-            to_delete = [i] + to_delete
+    with model_management_lock:
+        to_delete = []
+        for i in range(len(current_loaded_models)):
+            if sys.getrefcount(current_loaded_models[i].model) <= 2:
+                to_delete = [i] + to_delete

-    for i in to_delete:
-        x = current_loaded_models.pop(i)
-        x.model_unload()
-        del x
+        for i in to_delete:
+            x = current_loaded_models.pop(i)
+            x.model_unload()
+            del x

 def dtype_size(dtype):
     dtype_size = 4
@@ -593,26 +601,27 @@ def device_supports_non_blocking(device):
     return True

 def cast_to_device(tensor, device, dtype, copy=False):
-    device_supports_cast = False
-    if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
-        device_supports_cast = True
-    elif tensor.dtype == torch.bfloat16:
-        if hasattr(device, 'type') and device.type.startswith("cuda"):
-            device_supports_cast = True
-        elif is_intel_xpu():
-            device_supports_cast = True
+    with model_management_lock:
+        device_supports_cast = False
+        if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
+            device_supports_cast = True
+        elif tensor.dtype == torch.bfloat16:
+            if hasattr(device, 'type') and device.type.startswith("cuda"):
+                device_supports_cast = True
+            elif is_intel_xpu():
+                device_supports_cast = True

-    non_blocking = device_supports_non_blocking(device)
+        non_blocking = device_supports_non_blocking(device)

-    if device_supports_cast:
-        if copy:
-            if tensor.device == device:
-                return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
-            return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
-        else:
-            return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
-    else:
-        return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
+        if device_supports_cast:
+            if copy:
+                if tensor.device == device:
+                    return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
+                return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
+            else:
+                return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
+        else:
+            return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)

 def xformers_enabled():
     global directml_enabled
@@ -759,18 +768,20 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     return True

 def soft_empty_cache(force=False):
-    global cpu_state
-    if cpu_state == CPUState.MPS:
-        torch.mps.empty_cache()
-    elif is_intel_xpu():
-        torch.xpu.empty_cache()
-    elif torch.cuda.is_available():
-        if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+    with model_management_lock:
+        global cpu_state
+        if cpu_state == CPUState.MPS:
+            torch.mps.empty_cache()
+        elif is_intel_xpu():
+            torch.xpu.empty_cache()
+        elif torch.cuda.is_available():
+            if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
+                torch.cuda.empty_cache()
+                torch.cuda.ipc_collect()

 def unload_all_models():
-    free_memory(1e30, get_torch_device())
+    with model_management_lock:
+        free_memory(1e30, get_torch_device())


 def resolve_lowvram_weight(weight, model, key): #TODO: remove
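The race these locks close is the duplicate load: with more than one executor thread, two prompts can both fail the "already loaded" check in load_models_gpu and both copy the same checkpoint into VRAM. A standalone toy reproduction of that behaviour (not ComfyUI code; the names only echo the ones above):

import threading
import time

current_loaded_models = []                 # stand-in for the module-level list above
model_management_lock = threading.RLock()

def load_models_gpu_toy(name, shared_lock):
    # A private lock per call gives no protection, which is the pre-patch behaviour.
    lock = model_management_lock if shared_lock else threading.RLock()
    with lock:
        if name in current_loaded_models:  # "already loaded" fast path
            return
        time.sleep(0.05)                   # simulate the slow weight transfer to the GPU
        current_loaded_models.append(name) # a second copy lands here when the check races

for shared_lock in (False, True):
    current_loaded_models.clear()
    threads = [threading.Thread(target=load_models_gpu_toy, args=("sdxl", shared_lock)) for _ in range(2)]
    for t in threads: t.start()
    for t in threads: t.join()
    print("shared lock" if shared_lock else "no shared lock", current_loaded_models)
# typically prints: no shared lock ['sdxl', 'sdxl']  /  shared lock ['sdxl']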
@@ -1,3 +1,5 @@
+import asyncio
+
 import pytest
 import torch

@@ -31,3 +33,10 @@ async def test_embedded_comfy():
         prompt = sdxl_workflow_with_refiner("test")
         outputs = await client.queue_prompt(prompt)
         assert outputs["13"]["images"][0]["abs_path"] is not None
+
+@pytest.mark.asyncio
+async def test_multithreaded_comfy():
+    async with EmbeddedComfyClient(max_workers=2) as client:
+        prompt = sdxl_workflow_with_refiner("test")
+        outputs_iter = await asyncio.gather(*[client.queue_prompt(prompt) for _ in range(4)])
+        assert all(outputs["13"]["images"][0]["abs_path"] is not None for outputs in outputs_iter)