diff --git a/comfy/model_base.py b/comfy/model_base.py index 49efd700b..20fe3ead0 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -306,7 +306,10 @@ class BaseModel(torch.nn.Module): to_load[k[len(unet_prefix):]] = sd.pop(k) to_load = self.model_config.process_unet_state_dict(to_load) - m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + # assign=True will reuse the tensor storage in state dict, this will avoid copy and saving CPU memory + # when loading large models with mmap. + delay_copy_with_assign = utils.MMAP_TORCH_FILES or not utils.DISABLE_MMAP + m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=delay_copy_with_assign) if len(m) > 0: logging.warning("unet missing: {}".format(m))