mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-25 09:19:46 +08:00
Guard WSL CUDA sync during model load
This commit is contained in:
parent
69cbd50aa6
commit
63e08a02fd
@ -190,6 +190,14 @@ def is_wsl():
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
_WSL_SOFT_EMPTY_CACHE_SKIP_LOGGED = False
|
||||||
|
|
||||||
|
def wsl_skip_nonforced_soft_empty_cache():
|
||||||
|
return is_wsl() and os.getenv("COMFYUI_WSL_SOFT_EMPTY_CACHE", "0") != "1"
|
||||||
|
|
||||||
|
def wsl_skip_model_load_synchronize():
|
||||||
|
return is_wsl() and os.getenv("COMFYUI_WSL_MODEL_LOAD_SYNCHRONIZE", "0") != "1"
|
||||||
|
|
||||||
def get_torch_device():
|
def get_torch_device():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
global cpu_state
|
global cpu_state
|
||||||
@ -917,7 +925,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
if vram_set_state == VRAMState.NO_VRAM:
|
if vram_set_state == VRAMState.NO_VRAM:
|
||||||
lowvram_model_memory = 0.1
|
lowvram_model_memory = 0.1
|
||||||
|
|
||||||
|
model_name = model.model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__
|
||||||
|
logging.info(
|
||||||
|
f"Loading model {model_name} start: device={torch_dev} "
|
||||||
|
f"vram_state={vram_set_state.name} lowvram_model_memory={lowvram_model_memory} "
|
||||||
|
f"force_full_load={force_full_load} force_patch_weights={force_patch_weights}"
|
||||||
|
)
|
||||||
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||||
|
logging.info(f"Loading model {model_name} complete")
|
||||||
current_loaded_models.insert(0, loaded_model)
|
current_loaded_models.insert(0, loaded_model)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -1955,6 +1970,12 @@ def soft_empty_cache(force=False):
|
|||||||
elif is_mlu():
|
elif is_mlu():
|
||||||
torch.mlu.empty_cache()
|
torch.mlu.empty_cache()
|
||||||
elif torch.cuda.is_available():
|
elif torch.cuda.is_available():
|
||||||
|
if wsl_skip_nonforced_soft_empty_cache() and not force:
|
||||||
|
global _WSL_SOFT_EMPTY_CACHE_SKIP_LOGGED
|
||||||
|
if not _WSL_SOFT_EMPTY_CACHE_SKIP_LOGGED:
|
||||||
|
logging.info("Skipping non-forced CUDA soft_empty_cache on WSL; set COMFYUI_WSL_SOFT_EMPTY_CACHE=1 to re-enable.")
|
||||||
|
_WSL_SOFT_EMPTY_CACHE_SKIP_LOGGED = True
|
||||||
|
return
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
|
|||||||
@ -42,6 +42,8 @@ from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
|||||||
|
|
||||||
import comfy_aimdo.model_vbar
|
import comfy_aimdo.model_vbar
|
||||||
|
|
||||||
|
_WSL_MODEL_LOAD_SYNC_SKIP_LOGGED = False
|
||||||
|
|
||||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||||
to = model_options["transformer_options"].copy()
|
to = model_options["transformer_options"].copy()
|
||||||
|
|
||||||
@ -1005,6 +1007,11 @@ class ModelPatcher:
|
|||||||
mem_counter += move_weight_functions(m, device_to)
|
mem_counter += move_weight_functions(m, device_to)
|
||||||
|
|
||||||
load_completely.sort(reverse=True)
|
load_completely.sort(reverse=True)
|
||||||
|
skip_wsl_load_sync = comfy.model_management.is_device_cuda(device_to) and comfy.model_management.wsl_skip_model_load_synchronize()
|
||||||
|
global _WSL_MODEL_LOAD_SYNC_SKIP_LOGGED
|
||||||
|
if skip_wsl_load_sync and len(load_completely) > 0 and not _WSL_MODEL_LOAD_SYNC_SKIP_LOGGED:
|
||||||
|
logging.info("Skipping per-module CUDA synchronize during model load on WSL; set COMFYUI_WSL_MODEL_LOAD_SYNCHRONIZE=1 to re-enable.")
|
||||||
|
_WSL_MODEL_LOAD_SYNC_SKIP_LOGGED = True
|
||||||
for x in load_completely:
|
for x in load_completely:
|
||||||
n = x[1]
|
n = x[1]
|
||||||
m = x[2]
|
m = x[2]
|
||||||
@ -1019,7 +1026,7 @@ class ModelPatcher:
|
|||||||
key = key_param_name_to_key(n, param)
|
key = key_param_name_to_key(n, param)
|
||||||
self.unpin_weight(key)
|
self.unpin_weight(key)
|
||||||
self.patch_weight_to_device(key, device_to=device_to)
|
self.patch_weight_to_device(key, device_to=device_to)
|
||||||
if comfy.model_management.is_device_cuda(device_to):
|
if comfy.model_management.is_device_cuda(device_to) and not skip_wsl_load_sync:
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user