From d9b4607c368adf8b98d8610c157bfbbf1da5f7be Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Wed, 7 Feb 2024 15:18:13 -0800
Subject: [PATCH] Add locks to model_management to prevent multiple copies of
 the models from being loaded at the same time

---
 comfy/client/embedded_comfy_client.py     |  11 +-
 comfy/model_management.py                 | 239 +++++++++++-----------
 tests/distributed/test_embedded_client.py |   9 +
 3 files changed, 142 insertions(+), 117 deletions(-)

diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py
index 420e038a6..b760fdbf0 100644
--- a/comfy/client/embedded_comfy_client.py
+++ b/comfy/client/embedded_comfy_client.py
@@ -47,9 +47,10 @@ class EmbeddedComfyClient:
     ```
     """
 
-    def __init__(self, configuration: Optional[Configuration] = None, loop: Optional[AbstractEventLoop] = None):
+    def __init__(self, configuration: Optional[Configuration] = None, loop: Optional[AbstractEventLoop] = None,
+                 max_workers: int = 1):
         self._server_stub = ServerStub()
-        self._executor = ThreadPoolExecutor(max_workers=1)
+        self._executor = ThreadPoolExecutor(max_workers=max_workers)
         self._loop = loop or asyncio.get_event_loop()
         self._configuration = configuration
         # we don't want to import the executor yet
@@ -71,6 +72,10 @@ class EmbeddedComfyClient:
             except:
                 pass
 
+        # wait until the queue is done
+        while self._executor._work_queue.qsize() > 0:
+            await asyncio.sleep(0.1)
+
         await self._loop.run_in_executor(self._executor, cleanup)
 
         self._executor.shutdown(wait=True)
@@ -105,6 +110,6 @@ class EmbeddedComfyClient:
             if self._prompt_executor.success:
                 return self._prompt_executor.outputs_ui
             else:
-                raise RuntimeError("\n".join(self._prompt_executor.status_messages))
+                raise RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
 
         return await self._loop.run_in_executor(self._executor, execute_prompt)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 913ecc14c..20d3ed705 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -2,10 +2,13 @@ import psutil
 from enum import Enum
 from .cli_args import args
 from . import utils
+from threading import RLock
 
 import torch
 import sys
 
+model_management_lock = RLock()
+
 class VRAMState(Enum):
     DISABLED = 0    #No vram present: no need to move models to vram
     NO_VRAM = 1     #Very low vram: enable all the options to save vram
@@ -343,113 +346,118 @@ def minimum_inference_memory():
     return (1024 * 1024 * 1024)
 
 def unload_model_clones(model):
-    to_unload = []
-    for i in range(len(current_loaded_models)):
-        if model.is_clone(current_loaded_models[i].model):
-            to_unload = [i] + to_unload
+    with model_management_lock:
+        to_unload = []
+        for i in range(len(current_loaded_models)):
+            if model.is_clone(current_loaded_models[i].model):
+                to_unload = [i] + to_unload
 
-    for i in to_unload:
-        print("unload clone", i)
-        current_loaded_models.pop(i).model_unload()
+        for i in to_unload:
+            print("unload clone", i)
+            current_loaded_models.pop(i).model_unload()
 
 def free_memory(memory_required, device, keep_loaded=[]):
-    unloaded_model = False
-    for i in range(len(current_loaded_models) -1, -1, -1):
-        if not DISABLE_SMART_MEMORY:
-            if get_free_memory(device) > memory_required:
-                break
-        shift_model = current_loaded_models[i]
-        if shift_model.device == device:
-            if shift_model not in keep_loaded:
-                m = current_loaded_models.pop(i)
-                m.model_unload()
-                del m
-                unloaded_model = True
+    with model_management_lock:
+        unloaded_model = False
+        for i in range(len(current_loaded_models) -1, -1, -1):
+            if not DISABLE_SMART_MEMORY:
+                if get_free_memory(device) > memory_required:
+                    break
+            shift_model = current_loaded_models[i]
+            if shift_model.device == device:
+                if shift_model not in keep_loaded:
+                    m = current_loaded_models.pop(i)
+                    m.model_unload()
+                    del m
+                    unloaded_model = True
 
-    if unloaded_model:
-        soft_empty_cache()
-    else:
-        if vram_state != VRAMState.HIGH_VRAM:
-            mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
-            if mem_free_torch > mem_free_total * 0.25:
-                soft_empty_cache()
+        if unloaded_model:
+            soft_empty_cache()
+        else:
+            if vram_state != VRAMState.HIGH_VRAM:
+                mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
+                if mem_free_torch > mem_free_total * 0.25:
+                    soft_empty_cache()
 
 def load_models_gpu(models, memory_required=0):
     global vram_state
-    inference_memory = minimum_inference_memory()
-    extra_mem = max(inference_memory, memory_required)
+    with model_management_lock:
+        inference_memory = minimum_inference_memory()
+        extra_mem = max(inference_memory, memory_required)
 
-    models_to_load = []
-    models_already_loaded = []
-    for x in models:
-        loaded_model = LoadedModel(x)
+        models_to_load = []
+        models_already_loaded = []
+        for x in models:
+            loaded_model = LoadedModel(x)
 
-        if loaded_model in current_loaded_models:
-            index = current_loaded_models.index(loaded_model)
-            current_loaded_models.insert(0, current_loaded_models.pop(index))
-            models_already_loaded.append(loaded_model)
-        else:
-            if hasattr(x, "model"):
-                print(f"Requested to load {x.model.__class__.__name__}")
-            models_to_load.append(loaded_model)
-
-    if len(models_to_load) == 0:
-        devs = set(map(lambda a: a.device, models_already_loaded))
-        for d in devs:
-            if d != torch.device("cpu"):
-                free_memory(extra_mem, d, models_already_loaded)
-        return
-
-    print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
-
-    total_memory_required = {}
-    for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model)
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
-
-    for device in total_memory_required:
-        if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
-
-    for loaded_model in models_to_load:
-        model = loaded_model.model
-        torch_dev = model.load_device
-        if is_device_cpu(torch_dev):
-            vram_set_state = VRAMState.DISABLED
-        else:
-            vram_set_state = vram_state
-        lowvram_model_memory = 0
-        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
-            model_size = loaded_model.model_memory_required(torch_dev)
-            current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
-            if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
-                vram_set_state = VRAMState.LOW_VRAM
+            if loaded_model in current_loaded_models:
+                index = current_loaded_models.index(loaded_model)
+                current_loaded_models.insert(0, current_loaded_models.pop(index))
+                models_already_loaded.append(loaded_model)
             else:
-                lowvram_model_memory = 0
+                if hasattr(x, "model"):
+                    print(f"Requested to load {x.model.__class__.__name__}")
+                models_to_load.append(loaded_model)
 
-        if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 64 * 1024 * 1024
+        if len(models_to_load) == 0:
+            devs = set(map(lambda a: a.device, models_already_loaded))
+            for d in devs:
+                if d != torch.device("cpu"):
+                    free_memory(extra_mem, d, models_already_loaded)
+            return
 
-        cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
-        current_loaded_models.insert(0, loaded_model)
-    return
+        print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
+
+        total_memory_required = {}
+        for loaded_model in models_to_load:
+            unload_model_clones(loaded_model.model)
+            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+
+        for device in total_memory_required:
+            if device != torch.device("cpu"):
+                free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
+
+        for loaded_model in models_to_load:
+            model = loaded_model.model
+            torch_dev = model.load_device
+            if is_device_cpu(torch_dev):
+                vram_set_state = VRAMState.DISABLED
+            else:
+                vram_set_state = vram_state
+            lowvram_model_memory = 0
+            if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
+                model_size = loaded_model.model_memory_required(torch_dev)
+                current_free_mem = get_free_memory(torch_dev)
+                lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
+                if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
+                    vram_set_state = VRAMState.LOW_VRAM
+                else:
+                    lowvram_model_memory = 0
+
+            if vram_set_state == VRAMState.NO_VRAM:
+                lowvram_model_memory = 64 * 1024 * 1024
+
+            cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
+            current_loaded_models.insert(0, loaded_model)
+        return
 
 def load_model_gpu(model):
-    return load_models_gpu([model])
+    with model_management_lock:
+        return load_models_gpu([model])
 
 def cleanup_models():
-    to_delete = []
-    for i in range(len(current_loaded_models)):
-        if sys.getrefcount(current_loaded_models[i].model) <= 2:
-            to_delete = [i] + to_delete
+    with model_management_lock:
+        to_delete = []
+        for i in range(len(current_loaded_models)):
+            if sys.getrefcount(current_loaded_models[i].model) <= 2:
+                to_delete = [i] + to_delete
 
-    for i in to_delete:
-        x = current_loaded_models.pop(i)
-        x.model_unload()
-        del x
+        for i in to_delete:
+            x = current_loaded_models.pop(i)
+            x.model_unload()
+            del x
 
 def dtype_size(dtype):
     dtype_size = 4
@@ -593,26 +601,27 @@ def device_supports_non_blocking(device):
     return True
 
 def cast_to_device(tensor, device, dtype, copy=False):
-    device_supports_cast = False
-    if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
-        device_supports_cast = True
-    elif tensor.dtype == torch.bfloat16:
-        if hasattr(device, 'type') and device.type.startswith("cuda"):
-            device_supports_cast = True
-        elif is_intel_xpu():
+    with model_management_lock:
+        device_supports_cast = False
+        if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
             device_supports_cast = True
+        elif tensor.dtype == torch.bfloat16:
+            if hasattr(device, 'type') and device.type.startswith("cuda"):
+                device_supports_cast = True
+            elif is_intel_xpu():
+                device_supports_cast = True
 
-    non_blocking = device_supports_non_blocking(device)
+        non_blocking = device_supports_non_blocking(device)
 
-    if device_supports_cast:
-        if copy:
-            if tensor.device == device:
-                return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
-            return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
+        if device_supports_cast:
+            if copy:
+                if tensor.device == device:
+                    return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
+                return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
+            else:
+                return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
         else:
-            return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
-    else:
-        return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
+            return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
 
 def xformers_enabled():
     global directml_enabled
@@ -759,18 +768,20 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     return True
 
 def soft_empty_cache(force=False):
-    global cpu_state
-    if cpu_state == CPUState.MPS:
-        torch.mps.empty_cache()
-    elif is_intel_xpu():
-        torch.xpu.empty_cache()
-    elif torch.cuda.is_available():
-        if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+    with model_management_lock:
+        global cpu_state
+        if cpu_state == CPUState.MPS:
+            torch.mps.empty_cache()
+        elif is_intel_xpu():
+            torch.xpu.empty_cache()
+        elif torch.cuda.is_available():
+            if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
+                torch.cuda.empty_cache()
+                torch.cuda.ipc_collect()
 
 def unload_all_models():
-    free_memory(1e30, get_torch_device())
+    with model_management_lock:
+        free_memory(1e30, get_torch_device())
 
 
 def resolve_lowvram_weight(weight, model, key): #TODO: remove
diff --git a/tests/distributed/test_embedded_client.py b/tests/distributed/test_embedded_client.py
index b48d9b869..c4eeb3d15 100644
--- a/tests/distributed/test_embedded_client.py
+++ b/tests/distributed/test_embedded_client.py
@@ -1,3 +1,5 @@
+import asyncio
+
 import pytest
 import torch
 
@@ -31,3 +33,10 @@ async def test_embedded_comfy():
         prompt = sdxl_workflow_with_refiner("test")
         outputs = await client.queue_prompt(prompt)
         assert outputs["13"]["images"][0]["abs_path"] is not None
+
+@pytest.mark.asyncio
+async def test_multithreaded_comfy():
+    async with EmbeddedComfyClient(max_workers=2) as client:
+        prompt = sdxl_workflow_with_refiner("test")
+        outputs_iter = await asyncio.gather(*[client.queue_prompt(prompt) for _ in range(4)])
+        assert all(outputs["13"]["images"][0]["abs_path"] is not None for outputs in outputs_iter)
\ No newline at end of file
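
Note on the locking pattern: load_model_gpu() acquires model_management_lock and then calls load_models_gpu(), which acquires it again on the same thread; load_models_gpu() in turn calls unload_model_clones() and free_memory(), which also take the lock. That nested acquisition is why the patch uses threading.RLock (reentrant, owned per thread) rather than threading.Lock, which would deadlock on the second acquisition. A minimal standalone sketch of the difference, separate from the patch:

    from threading import RLock

    lock = RLock()

    def outer():
        with lock:      # first acquisition by this thread
            inner()

    def inner():
        with lock:      # reacquired by the same thread: fine with an RLock;
            pass        # a plain threading.Lock would block here forever

    outer()  # returns normally

With max_workers > 1 in EmbeddedComfyClient, two executor threads can reach load_models_gpu() concurrently; the lock serializes them, so the second thread observes a consistent current_loaded_models list and moves the already-loaded model to the front instead of loading a duplicate copy, which is the behavior test_multithreaded_comfy exercises.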