diff --git a/comfy/model_base.py b/comfy/model_base.py index 6c8ee69b4..75d469221 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -60,6 +60,7 @@ import math from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher +from comfy.model_management import get_free_memory class ModelType(Enum): EPS = 1 @@ -291,18 +292,19 @@ class BaseModel(torch.nn.Module): return out def load_model_weights(self, sd, unet_prefix=""): - import pdb; pdb.set_trace() to_load = {} keys = list(sd.keys()) for k in keys: if k.startswith(unet_prefix): to_load[k[len(unet_prefix):]] = sd.pop(k) - logging.info(f"load model weights start, keys {keys}") + free_cpu_memory = get_free_memory(torch.device("cpu")) + logging.info(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") to_load = self.model_config.process_unet_state_dict(to_load) - logging.info(f"load model {self.model_config} weights process end, keys {keys}") + logging.info(f"load model {self.model_config} weights process end") m, u = self.diffusion_model.load_state_dict(to_load, strict=False) - logging.info(f"load model {self.model_config} weights end, keys {keys}") + free_cpu_memory = get_free_memory(torch.device("cpu")) + logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") if len(m) > 0: logging.warning("unet missing: {}".format(m))