mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 16:29:25 +08:00
Merge pull request #4 from xmarre/codex/wsl-model-load-guard
Guard WSL CUDA sync during model load
This commit is contained in:
commit
fa0eaccfcc
@ -190,6 +190,14 @@ def is_wsl():
|
||||
return True
|
||||
return False
|
||||
|
||||
_WSL_SOFT_EMPTY_CACHE_SKIP_LOGGED = False
|
||||
|
||||
def wsl_skip_nonforced_soft_empty_cache():
|
||||
return is_wsl() and os.getenv("COMFYUI_WSL_SOFT_EMPTY_CACHE", "0") != "1"
|
||||
|
||||
def wsl_skip_model_load_synchronize():
|
||||
return is_wsl() and os.getenv("COMFYUI_WSL_MODEL_LOAD_SYNCHRONIZE", "0") != "1"
|
||||
|
||||
def get_torch_device():
|
||||
global directml_enabled
|
||||
global cpu_state
|
||||
@ -917,7 +925,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
lowvram_model_memory = 0.1
|
||||
|
||||
model_name = model.model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__
|
||||
logging.info(
|
||||
f"Loading model {model_name} start: device={torch_dev} "
|
||||
f"vram_state={vram_set_state.name} lowvram_model_memory={lowvram_model_memory} "
|
||||
f"force_full_load={force_full_load} force_patch_weights={force_patch_weights}"
|
||||
)
|
||||
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
logging.info(f"Loading model {model_name} complete")
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
return
|
||||
|
||||
@ -1955,6 +1970,12 @@ def soft_empty_cache(force=False):
|
||||
elif is_mlu():
|
||||
torch.mlu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
if wsl_skip_nonforced_soft_empty_cache() and not force:
|
||||
global _WSL_SOFT_EMPTY_CACHE_SKIP_LOGGED
|
||||
if not _WSL_SOFT_EMPTY_CACHE_SKIP_LOGGED:
|
||||
logging.info("Skipping non-forced CUDA soft_empty_cache on WSL; set COMFYUI_WSL_SOFT_EMPTY_CACHE=1 to re-enable.")
|
||||
_WSL_SOFT_EMPTY_CACHE_SKIP_LOGGED = True
|
||||
return
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
@ -42,6 +42,8 @@ from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||
|
||||
import comfy_aimdo.model_vbar
|
||||
|
||||
_WSL_MODEL_LOAD_SYNC_SKIP_LOGGED = False
|
||||
|
||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||
to = model_options["transformer_options"].copy()
|
||||
|
||||
@ -1005,6 +1007,11 @@ class ModelPatcher:
|
||||
mem_counter += move_weight_functions(m, device_to)
|
||||
|
||||
load_completely.sort(reverse=True)
|
||||
skip_wsl_load_sync = comfy.model_management.is_device_cuda(device_to) and comfy.model_management.wsl_skip_model_load_synchronize()
|
||||
global _WSL_MODEL_LOAD_SYNC_SKIP_LOGGED
|
||||
if skip_wsl_load_sync and len(load_completely) > 0 and not _WSL_MODEL_LOAD_SYNC_SKIP_LOGGED:
|
||||
logging.info("Skipping per-module CUDA synchronize during model load on WSL; set COMFYUI_WSL_MODEL_LOAD_SYNCHRONIZE=1 to re-enable.")
|
||||
_WSL_MODEL_LOAD_SYNC_SKIP_LOGGED = True
|
||||
for x in load_completely:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
@ -1019,7 +1026,7 @@ class ModelPatcher:
|
||||
key = key_param_name_to_key(n, param)
|
||||
self.unpin_weight(key)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
if comfy.model_management.is_device_cuda(device_to):
|
||||
if comfy.model_management.is_device_cuda(device_to) and not skip_wsl_load_sync:
|
||||
torch.cuda.synchronize()
|
||||
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user