# Patch summary (free-text preamble; patch tools skip text before the first
# "diff --git" header).
#
# Purpose: on WSL, skip two CUDA synchronisation points by default, each
# with an environment-variable opt-out, and add start/complete logging
# around model loading.
#
# comfy/model_management.py
#   * Adds the module flag _WSL_SOFT_EMPTY_CACHE_SKIP_LOGGED (initially False).
#   * Adds wsl_skip_nonforced_soft_empty_cache(): true when is_wsl() is true
#     and COMFYUI_WSL_SOFT_EMPTY_CACHE (default "0") is not exactly "1".
#   * Adds wsl_skip_model_load_synchronize(): same shape, keyed on
#     COMFYUI_WSL_MODEL_LOAD_SYNCHRONIZE.
#   * load_models_gpu: logs at INFO before and after
#     loaded_model.model_load(...). The "start" message carries the model's
#     class name, torch_dev, vram_set_state.name, lowvram_model_memory,
#     force_full_load and force_patch_weights.
#   * soft_empty_cache: inside the torch.cuda.is_available() branch, when the
#     helper is true and force is false, logs a one-time INFO message (guarded
#     by the module flag) and returns BEFORE torch.cuda.synchronize(),
#     torch.cuda.empty_cache() and torch.cuda.ipc_collect() -- all three are
#     skipped, not only the synchronize.
#
# comfy/model_patcher.py
#   * Adds the module flag _WSL_MODEL_LOAD_SYNC_SKIP_LOGGED (initially False).
#   * In the ModelPatcher hunk: computes skip_wsl_load_sync once as
#     is_device_cuda(device_to) and wsl_skip_model_load_synchronize(), logs a
#     one-time INFO message when load_completely is non-empty, and makes the
#     existing torch.cuda.synchronize() call additionally conditional on
#     "not skip_wsl_load_sync".
#
# NOTE(review): the line structure of this patch looks collapsed (newlines
#   inside the hunks appear as single spaces and the final hunk header is
#   split across two physical lines). It will not apply in this form;
#   regenerate it from the originating commit before use.
# NOTE(review): both opt-outs match the literal string "1" only; values such
#   as "true", "yes" or " 1" leave the skip active -- confirm that is intended.
# NOTE(review): the helpers call os.getenv and this patch adds no import;
#   presumably "os" is already imported in model_management.py -- verify.
# NOTE(review): "global _WSL_SOFT_EMPTY_CACHE_SKIP_LOGGED" is declared inside
#   a nested "if". Python accepts that only if the name is not used earlier in
#   the same function; the rest of soft_empty_cache is not visible here --
#   confirm. The same applies to the mid-function "global" in ModelPatcher.
# NOTE(review): the rewritten condition still calls
#   comfy.model_management.is_device_cuda(device_to), apparently once per item
#   of load_completely, although device_to does not change and the result is
#   already folded into skip_wsl_load_sync -- a candidate for hoisting.
# NOTE(review): the new "Loading model ... start/complete" messages are
#   emitted at INFO for every model load -- confirm the verbosity is wanted
#   (the neighbouring "lowvram: loaded module regularly" message uses DEBUG).
diff --git a/comfy/model_management.py b/comfy/model_management.py index b01c4d7fa..83cb5d277 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -190,6 +190,14 @@ def is_wsl(): return True return False +_WSL_SOFT_EMPTY_CACHE_SKIP_LOGGED = False + +def wsl_skip_nonforced_soft_empty_cache(): + return is_wsl() and os.getenv("COMFYUI_WSL_SOFT_EMPTY_CACHE", "0") != "1" + +def wsl_skip_model_load_synchronize(): + return is_wsl() and os.getenv("COMFYUI_WSL_MODEL_LOAD_SYNCHRONIZE", "0") != "1" + def get_torch_device(): global directml_enabled global cpu_state @@ -917,7 +925,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if vram_set_state == VRAMState.NO_VRAM: lowvram_model_memory = 0.1 + model_name = model.model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__ + logging.info( + f"Loading model {model_name} start: device={torch_dev} " + f"vram_state={vram_set_state.name} lowvram_model_memory={lowvram_model_memory} " + f"force_full_load={force_full_load} force_patch_weights={force_patch_weights}" + ) loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights) + logging.info(f"Loading model {model_name} complete") current_loaded_models.insert(0, loaded_model) return @@ -1955,6 +1970,12 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): + if wsl_skip_nonforced_soft_empty_cache() and not force: + global _WSL_SOFT_EMPTY_CACHE_SKIP_LOGGED + if not _WSL_SOFT_EMPTY_CACHE_SKIP_LOGGED: + logging.info("Skipping non-forced CUDA soft_empty_cache on WSL; set COMFYUI_WSL_SOFT_EMPTY_CACHE=1 to re-enable.") + _WSL_SOFT_EMPTY_CACHE_SKIP_LOGGED = True + return torch.cuda.synchronize() torch.cuda.empty_cache() torch.cuda.ipc_collect() diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 00a15fa63..d7a2fb704 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -42,6 +42,8 @@ from 
comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP import comfy_aimdo.model_vbar +_WSL_MODEL_LOAD_SYNC_SKIP_LOGGED = False + def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -1005,6 +1007,11 @@ class ModelPatcher: mem_counter += move_weight_functions(m, device_to) load_completely.sort(reverse=True) + skip_wsl_load_sync = comfy.model_management.is_device_cuda(device_to) and comfy.model_management.wsl_skip_model_load_synchronize() + global _WSL_MODEL_LOAD_SYNC_SKIP_LOGGED + if skip_wsl_load_sync and len(load_completely) > 0 and not _WSL_MODEL_LOAD_SYNC_SKIP_LOGGED: + logging.info("Skipping per-module CUDA synchronize during model load on WSL; set COMFYUI_WSL_MODEL_LOAD_SYNCHRONIZE=1 to re-enable.") + _WSL_MODEL_LOAD_SYNC_SKIP_LOGGED = True for x in load_completely: n = x[1] m = x[2] @@ -1019,7 +1026,7 @@ class ModelPatcher: key = key_param_name_to_key(n, param) self.unpin_weight(key) self.patch_weight_to_device(key, device_to=device_to) - if comfy.model_management.is_device_cuda(device_to): + if comfy.model_management.is_device_cuda(device_to) and not skip_wsl_load_sync: torch.cuda.synchronize() logging.debug("lowvram: loaded module regularly {} {}".format(n, m))